Skip to content
Snippets Groups Projects
Commit b9a8d9b7 authored by Yandi's avatar Yandi
Browse files

[First try]

parent f69df593
No related branches found
No related tags found
1 merge request!1Master into main
......@@ -7,6 +7,11 @@ import losses
import optimizers
import torch
import logging
import torch.optim
def optimizer(cfg, model):
    """Build the optimizer named by cfg["Optimizer"] for the model's parameters.

    Args:
        cfg: configuration mapping; cfg["Optimizer"] selects the algorithm
             (currently only "Adam" is supported).
        model: torch.nn.Module whose parameters will be optimized.

    Returns:
        A configured torch.optim.Optimizer instance over model.parameters().

    Raises:
        KeyError: if cfg["Optimizer"] is not a supported name.
    """
    # Map names to constructors (not instances) so only the requested
    # optimizer is actually built; the original built every entry eagerly.
    factories = {"Adam": torch.optim.Adam}
    return factories[cfg["Optimizer"]](model.parameters())
if __name__ == "__main__":
logging.basicConfig(filename='logs/main_unit_test.log', level=logging.INFO)
......@@ -43,7 +48,7 @@ if __name__ == "__main__":
f_loss = losses.RMSLELoss()
optimizer = optimizers.optimizer(cfg, model)
optimizer = optimizer(cfg, model)
for t in range(cfg["Training"]["Epochs"]):
print("Epoch {}".format(t))
......
......@@ -15,8 +15,7 @@ class LinearRegression(nn.Module):
self.activate = nn.ReLU()
def forward(self, x):
    """Apply the linear regressor followed by a ReLU activation.

    Args:
        x: input tensor; assumes x.shape[0] is the batch size — TODO confirm
           against callers.

    Returns:
        Non-negative predictions reshaped to (x.shape[0], -1).
    """
    # The scraped diff left an unreachable `return self.activate(y)` after
    # `return y`; this keeps the intended (new) single-return behavior.
    y = self.regressor(x).view((x.shape[0], -1))
    # ReLU clamps predictions to be non-negative.
    return self.activate(y)
def build_model(cfg, input_size):
    """Construct the model class named in cfg['Model']['Name'].

    Args:
        cfg: configuration mapping; cfg['Model']['Name'] must be the name of
             a model class defined in this module (e.g. LinearRegression).
        input_size: forwarded to the model constructor.

    Returns:
        An instance of the requested model class, built as Name(cfg, input_size).

    Raises:
        KeyError: if the named class is not defined in this module.
    """
    # Look the class up in module globals instead of eval(): identical
    # behavior for valid class names, but a malicious or mistyped config
    # string can no longer execute arbitrary code.
    return globals()[cfg["Model"]["Name"]](cfg, input_size)
......
import torch.optim
def optimizer(cfg, model):
    """Instantiate the optimizer selected by cfg["Optimizer"].

    Args:
        cfg: configuration mapping with an "Optimizer" key (e.g. "Adam").
        model: torch.nn.Module supplying the parameters to optimize.

    Returns:
        A torch.optim.Optimizer over model.parameters() with lr=1e-2.

    Raises:
        KeyError: if cfg["Optimizer"] is not a supported name.
    """
    # The scraped diff left two consecutive assignments to `result`
    # (the first, default-lr one was dead code). Keep only the intended
    # lr=1e-2 version, and dispatch on constructors so nothing is built
    # until the name lookup succeeds.
    factories = {"Adam": torch.optim.Adam}
    return factories[cfg["Optimizer"]](model.parameters(), lr=1e-2)
......@@ -30,4 +30,7 @@ def train(model, loader, f_loss, optimizer, device):
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
\ No newline at end of file
optimizer.step()
print(model.regressor.weight)
print(model.regressor.bias)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment