Skip to content
Snippets Groups Projects
Commit b9a8d9b7 authored by Yandi's avatar Yandi
Browse files

[First try]

parent f69df593
No related branches found
No related tags found
1 merge request!1Master into main
...@@ -7,6 +7,11 @@ import losses ...@@ -7,6 +7,11 @@ import losses
import optimizers import optimizers
import torch import torch
import logging import logging
import torch.optim
def optimizer(cfg, model):
    """Build the optimizer selected by ``cfg["Optimizer"]`` for *model*.

    Parameters
    ----------
    cfg : dict
        Configuration mapping; ``cfg["Optimizer"]`` names the algorithm
        (currently only ``"Adam"`` is supported).
    model : torch.nn.Module
        Model whose parameters will be optimized.

    Returns
    -------
    torch.optim.Optimizer
        A freshly constructed optimizer over ``model.parameters()``.

    Raises
    ------
    KeyError
        If ``cfg["Optimizer"]`` names an unsupported algorithm
        (same behavior as the original dict lookup).
    """
    # Map names to classes so only the requested optimizer is constructed.
    # The original built every dict entry eagerly before the lookup, which
    # would instantiate all optimizers once more are added.
    factories = {"Adam": torch.optim.Adam}
    return factories[cfg["Optimizer"]](model.parameters())
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(filename='logs/main_unit_test.log', level=logging.INFO) logging.basicConfig(filename='logs/main_unit_test.log', level=logging.INFO)
...@@ -43,7 +48,7 @@ if __name__ == "__main__": ...@@ -43,7 +48,7 @@ if __name__ == "__main__":
f_loss = losses.RMSLELoss() f_loss = losses.RMSLELoss()
optimizer = optimizers.optimizer(cfg, model) optimizer = optimizer(cfg, model)
for t in range(cfg["Training"]["Epochs"]): for t in range(cfg["Training"]["Epochs"]):
print("Epoch {}".format(t)) print("Epoch {}".format(t))
......
...@@ -15,8 +15,7 @@ class LinearRegression(nn.Module): ...@@ -15,8 +15,7 @@ class LinearRegression(nn.Module):
self.activate = nn.ReLU() self.activate = nn.ReLU()
def forward(self, x): def forward(self, x):
y = self.regressor(x).view((x.shape[0],-1)) y = self.regressor(x).view((x.shape[0],-1))
y = self.activate(y) return self.activate(y)
return y
def build_model(cfg, input_size): def build_model(cfg, input_size):
return eval(f"{cfg['Model']['Name']}(cfg, input_size)") return eval(f"{cfg['Model']['Name']}(cfg, input_size)")
......
import torch.optim import torch.optim
def optimizer(cfg, model): def optimizer(cfg, model):
result = {"Adam" : torch.optim.Adam(model.parameters())} result = {"Adam" : torch.optim.Adam(model.parameters(), lr = 1e-2)}
return result[cfg["Optimizer"]] return result[cfg["Optimizer"]]
...@@ -30,4 +30,7 @@ def train(model, loader, f_loss, optimizer, device): ...@@ -30,4 +30,7 @@ def train(model, loader, f_loss, optimizer, device):
# Backward and optimize # Backward and optimize
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
\ No newline at end of file
print(model.regressor.weight)
print(model.regressor.bias)
\ No newline at end of file
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.