From ec5f4bfad2cab493527c524c5c66278765b06bba Mon Sep 17 00:00:00 2001
From: Yandi <yandirzm@gmail.com>
Date: Sun, 22 Jan 2023 12:51:27 +0100
Subject: [PATCH] [Dev] Printing, removing prints

---
 train.py | 44 +++++++++++++++++++++++++++++++-------------
 1 file changed, 31 insertions(+), 13 deletions(-)

diff --git a/train.py b/train.py
index 489959f..ef930e9 100644
--- a/train.py
+++ b/train.py
@@ -1,4 +1,6 @@
 from tqdm import tqdm
+import matplotlib.pyplot as plt
+import numpy as np
 
 def train(model, loader, f_loss, optimizer, device):
     """
@@ -19,7 +21,9 @@ def train(model, loader, f_loss, optimizer, device):
     """
 
     model.train()
-
+    gradients = []
+    out         = []
+    tar         = [] 
     for _, (inputs, targets) in tqdm(enumerate(loader), total = len(loader)):
         inputs, targets = inputs.to(device), targets.to(device)
 
@@ -27,20 +31,34 @@ def train(model, loader, f_loss, optimizer, device):
         outputs = model(inputs)
         loss = f_loss(outputs, targets)
 
-        print("Loss")
-        print(loss)
-
-        print("outputs")
-        print(outputs)
-
-        print("targets")
-        print(targets)
-
         # Backward and optimize
         optimizer.zero_grad()
         loss.backward()
 
-        print("GRads")
-        print(list(model.parameters())[0].grad)
+        
+        #Y = list(model.parameters())[0].grad.cpu().tolist()
+        
+        #gradients.append(np.mean(Y))
+        #tar.append(np.mean(outputs.cpu().tolist()))
+        #out.append(np.mean(targets.cpu().tolist()))
+        
+        optimizer.step()
+    #visualize_gradients(gradients)
+    #visualize_gradients(tar)
+    #visualize_gradients(out)
+
+def visualize_gradients(gradients):
+    print(gradients)
+    import numpy as np
+    X = np.linspace(0,len(gradients),len(gradients))
+    plt.scatter(X,gradients)
+    plt.show()
 
-        optimizer.step()
\ No newline at end of file
+if __name__=="__main__":
+    import numpy as np
+    Y = [[1,2,3],[2,4,8],[2,5,6], [8,9,10]]
+    X = np.linspace(0,len(Y),len(Y))
+    for i,curve in enumerate(Y):
+        for point in curve : 
+            plt.scatter(X[i],point)
+    plt.show()
\ No newline at end of file
-- 
GitLab