From b9a8d9b7afb568f819568f7976748730463fc1ce Mon Sep 17 00:00:00 2001 From: Yandi <yandirzm@gmail.com> Date: Sat, 21 Jan 2023 22:45:15 +0100 Subject: [PATCH] [First try] --- main.py | 7 ++++++- model.py | 3 +-- optimizers.py | 2 +- train.py | 5 ++++- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index bbc7091..cca6083 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,11 @@ import losses import optimizers import torch import logging +import torch.optim + +def optimizer(cfg, model): + result = {"Adam" : torch.optim.Adam(model.parameters())} + return result[cfg["Optimizer"]] if __name__ == "__main__": logging.basicConfig(filename='logs/main_unit_test.log', level=logging.INFO) @@ -43,7 +48,7 @@ if __name__ == "__main__": f_loss = losses.RMSLELoss() - optimizer = optimizers.optimizer(cfg, model) + optimizer = optimizer(cfg, model) for t in range(cfg["Training"]["Epochs"]): print("Epoch {}".format(t)) diff --git a/model.py b/model.py index 84716f3..4e6830c 100644 --- a/model.py +++ b/model.py @@ -15,8 +15,7 @@ class LinearRegression(nn.Module): self.activate = nn.ReLU() def forward(self, x): y = self.regressor(x).view((x.shape[0],-1)) - y = self.activate(y) - return y + return self.activate(y) def build_model(cfg, input_size): return eval(f"{cfg['Model']['Name']}(cfg, input_size)") diff --git a/optimizers.py b/optimizers.py index a2d7704..f3ca790 100644 --- a/optimizers.py +++ b/optimizers.py @@ -1,5 +1,5 @@ import torch.optim def optimizer(cfg, model): - result = {"Adam" : torch.optim.Adam(model.parameters())} + result = {"Adam" : torch.optim.Adam(model.parameters(), lr = 1e-2)} return result[cfg["Optimizer"]] diff --git a/train.py b/train.py index 0217b12..11c9e2e 100644 --- a/train.py +++ b/train.py @@ -30,4 +30,7 @@ def train(model, loader, f_loss, optimizer, device): # Backward and optimize optimizer.zero_grad() loss.backward() - optimizer.step() \ No newline at end of file + optimizer.step() + + print(model.regressor.weight) + print(model.regressor.bias) \ No newline at end of file -- GitLab