From b9a8d9b7afb568f819568f7976748730463fc1ce Mon Sep 17 00:00:00 2001
From: Yandi <yandirzm@gmail.com>
Date: Sat, 21 Jan 2023 22:45:15 +0100
Subject: [PATCH] [First try]

---
 main.py       | 7 ++++++-
 model.py      | 3 +--
 optimizers.py | 2 +-
 train.py      | 5 ++++-
 4 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/main.py b/main.py
index bbc7091..cca6083 100644
--- a/main.py
+++ b/main.py
@@ -7,6 +7,11 @@ import losses
 import optimizers
 import torch
 import logging
+import torch.optim
+
+def optimizer(cfg, model):
+    result = {"Adam" : torch.optim.Adam(model.parameters())}
+    return result[cfg["Optimizer"]]
 
 if __name__ == "__main__":
     logging.basicConfig(filename='logs/main_unit_test.log', level=logging.INFO)
@@ -43,7 +48,7 @@ if __name__ == "__main__":
 
     f_loss = losses.RMSLELoss()
 
-    optimizer = optimizers.optimizer(cfg, model)
+    optimizer = optimizer(cfg, model)
   
     for t in range(cfg["Training"]["Epochs"]):
         print("Epoch {}".format(t))
diff --git a/model.py b/model.py
index 84716f3..4e6830c 100644
--- a/model.py
+++ b/model.py
@@ -15,8 +15,7 @@ class LinearRegression(nn.Module):
         self.activate = nn.ReLU()
     def forward(self, x):
         y = self.regressor(x).view((x.shape[0],-1))
-        y = self.activate(y)
-        return y
+        return self.activate(y)
 
 def build_model(cfg, input_size):    
     return eval(f"{cfg['Model']['Name']}(cfg, input_size)")
diff --git a/optimizers.py b/optimizers.py
index a2d7704..f3ca790 100644
--- a/optimizers.py
+++ b/optimizers.py
@@ -1,5 +1,5 @@
 import torch.optim
 
 def optimizer(cfg, model):
-    result = {"Adam" : torch.optim.Adam(model.parameters())}
+    result = {"Adam" : torch.optim.Adam(model.parameters(), lr = 1e-2)}
     return result[cfg["Optimizer"]]
diff --git a/train.py b/train.py
index 0217b12..11c9e2e 100644
--- a/train.py
+++ b/train.py
@@ -30,4 +30,7 @@ def train(model, loader, f_loss, optimizer, device):
         # Backward and optimize
         optimizer.zero_grad()
         loss.backward()
-        optimizer.step()
\ No newline at end of file
+        optimizer.step()
+
+        print(model.regressor.weight)
+        print(model.regressor.bias)
\ No newline at end of file
-- 
GitLab