From ac79a9d9f5163d99e04af68255c7bfa5e7dceb7b Mon Sep 17 00:00:00 2001 From: Yandi <yandirzm@gmail.com> Date: Mon, 6 Feb 2023 20:44:41 +0100 Subject: [PATCH] pushin my_train script --- my_train.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 my_train.py diff --git a/my_train.py b/my_train.py new file mode 100644 index 0000000..778b6c7 --- /dev/null +++ b/my_train.py @@ -0,0 +1,66 @@ +from tqdm import tqdm +import matplotlib.pyplot as plt +import numpy as np +import torch +import wandb + +def train(args, model, loader, f_loss, optimizer, device, log_interval = 100): + """ + Train a model for one epoch, iterating over the loader + using the f_loss to compute the loss and the optimizer + to update the parameters of the model. + + Arguments : + + model -- A torch.nn.Module object + loader -- A torch.utils.data.DataLoader + f_loss -- The loss function, i.e. a loss Module + optimizer -- A torch.optim.Optimzer object + device -- a torch.device class specifying the device + used for computation + + Returns : + """ + + model.train() + gradients = [] + out = [] + tar = [] + for batch_idx, (inputs, targets) in tqdm(enumerate(loader), total = len(loader)): + inputs, targets = inputs.to(device), targets.to(device) + + # Compute the forward pass through the network up to the loss + + # target's shape is (B, Num_days, 1) + outputs = model(inputs) + loss = f_loss(outputs, targets) + + # Backward and optimize + optimizer.zero_grad() + loss.backward() + + #torch.nn.utils.clip_grad_norm(model.parameters(), 50) + + Y = list(model.parameters())[0].grad.cpu().tolist() + + + if not args.no_wandb: + if batch_idx % log_interval == 0: + wandb.log({"train_loss" : loss}) + optimizer.step() + +def visualize_gradients(gradients): + print(gradients) + import numpy as np + X = np.linspace(0,len(gradients),len(gradients)) + plt.scatter(X,gradients) + plt.show() + +if __name__=="__main__": + import numpy as np + Y = [[1,2,3],[2,4,8],[2,5,6], [8,9,10]] + X = np.linspace(0,len(Y),len(Y)) + for i,curve in enumerate(Y): + for point in curve : + plt.scatter(X[i],point) + plt.show() \ No newline at end of file -- GitLab