from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb


def train(args, model, loader, f_loss, optimizer, device, log_interval=100):
    """
    Train a model for one epoch, iterating over the loader,
    using f_loss to compute the loss and the optimizer
    to update the parameters of the model.

    Arguments:
        args -- The parsed command-line arguments; args.no_wandb disables wandb logging
        model -- A torch.nn.Module object
        loader -- A torch.utils.data.DataLoader
        f_loss -- The loss function, i.e. a loss Module
        optimizer -- A torch.optim.Optimizer object
        device -- A torch.device specifying the device used for computation
        log_interval -- Number of batches between two wandb log calls

    Returns:
        None
    """
    model.train()

    # Buffers for the optional per-batch debugging statistics collected by the
    # commented-out lines in the loop below.
    gradients = []
    out = []
    tar = []

    for batch_idx, (inputs, targets) in tqdm(enumerate(loader), total=len(loader)):
        inputs, targets = inputs.to(device), targets.to(device)

        # Compute the forward pass through the network up to the loss
        outputs = model(inputs)
        loss = f_loss(outputs, targets)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 50)

        # Optional debugging: record the mean gradient of the first parameter
        # tensor and the mean output/target values for this batch.
        # gradients.append(np.mean(list(model.parameters())[0].grad.cpu().tolist()))
        # out.append(np.mean(outputs.detach().cpu().tolist()))
        # tar.append(np.mean(targets.cpu().tolist()))

        if not args.no_wandb and batch_idx % log_interval == 0:
            wandb.log({"train_loss": loss.item()})

        optimizer.step()

    # visualize_gradients(gradients)
    # visualize_gradients(out)
    # visualize_gradients(tar)


def visualize_gradients(gradients):
    """Scatter-plot a list of scalar statistics against their batch index."""
    print(gradients)
    X = np.arange(len(gradients))
    plt.scatter(X, gradients)
    plt.show()


if __name__ == "__main__":
    # Smoke test for the plotting logic: scatter a few curves point by point.
    Y = [[1, 2, 3], [2, 4, 8], [2, 5, 6], [8, 9, 10]]
    X = np.arange(len(Y))
    for i, curve in enumerate(Y):
        for point in curve:
            plt.scatter(X[i], point)
    plt.show()
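

# A minimal sketch of how train() is meant to be driven. The argparse
# namespace, toy model, random dataset, and optimizer below are illustrative
# assumptions, not the project's real entry point; only args.no_wandb is
# actually read by train().
def _example_usage():
    import argparse

    args = argparse.Namespace(no_wandb=True)  # hypothetical flag layout
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A toy classifier and a random dataset, purely for illustration.
    model = torch.nn.Linear(10, 2).to(device)
    dataset = torch.utils.data.TensorDataset(
        torch.randn(256, 10), torch.randint(0, 2, (256,))
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=32)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # One epoch of training; with no_wandb=True, nothing is logged to wandb.
    train(args, model, loader, torch.nn.CrossEntropyLoss(), optimizer, device)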