Skip to content
Snippets Groups Projects
Commit ec5f4bfa authored by Yandi's avatar Yandi
Browse files

[Dev] Printing, removing prints

parent 3c86644c
No related branches found
No related tags found
1 merge request!1Master into main
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
def train(model, loader, f_loss, optimizer, device):
"""
......@@ -19,7 +21,9 @@ def train(model, loader, f_loss, optimizer, device):
"""
model.train()
gradients = []
out = []
tar = []
for _, (inputs, targets) in tqdm(enumerate(loader), total = len(loader)):
inputs, targets = inputs.to(device), targets.to(device)
......@@ -27,20 +31,34 @@ def train(model, loader, f_loss, optimizer, device):
outputs = model(inputs)
loss = f_loss(outputs, targets)
print("Loss")
print(loss)
print("outputs")
print(outputs)
print("targets")
print(targets)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
print("GRads")
print(list(model.parameters())[0].grad)
#Y = list(model.parameters())[0].grad.cpu().tolist()
#gradients.append(np.mean(Y))
#tar.append(np.mean(outputs.cpu().tolist()))
#out.append(np.mean(targets.cpu().tolist()))
optimizer.step()
#visualize_gradients(gradients)
#visualize_gradients(tar)
#visualize_gradients(out)
def visualize_gradients(gradients):
print(gradients)
import numpy as np
X = np.linspace(0,len(gradients),len(gradients))
plt.scatter(X,gradients)
plt.show()
optimizer.step()
\ No newline at end of file
if __name__=="__main__":
import numpy as np
Y = [[1,2,3],[2,4,8],[2,5,6], [8,9,10]]
X = np.linspace(0,len(Y),len(Y))
for i,curve in enumerate(Y):
for point in curve :
plt.scatter(X[i],point)
plt.show()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment