diff --git a/my_train.py b/my_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..778b6c7e356fa2152a9edc82e92b40946d401639
--- /dev/null
+++ b/my_train.py
@@ -0,0 +1,66 @@
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import wandb
+
+def train(args, model, loader, f_loss, optimizer, device, log_interval = 100):
+    """
+    Train a model for one epoch, iterating over the loader
+    using the f_loss to compute the loss and the optimizer
+    to update the parameters of the model.
+
+    Arguments :
+
+        model     -- A torch.nn.Module object
+        loader    -- A torch.utils.data.DataLoader
+        f_loss    -- The loss function, i.e. a loss Module
+        optimizer -- A torch.optim.Optimzer object
+        device    -- a torch.device class specifying the device
+                     used for computation
+
+    Returns :
+    """
+
+    model.train()
+    gradients = []
+    out         = []
+    tar         = [] 
+    for batch_idx, (inputs, targets) in tqdm(enumerate(loader), total = len(loader)):
+        inputs, targets = inputs.to(device), targets.to(device)
+
+        # Compute the forward pass through the network up to the loss
+
+        # target's shape is (B, Num_days, 1)
+        outputs = model(inputs)
+        loss = f_loss(outputs, targets)
+
+        # Backward and optimize
+        optimizer.zero_grad()
+        loss.backward()
+
+        #torch.nn.utils.clip_grad_norm(model.parameters(), 50)
+        
+        Y = list(model.parameters())[0].grad.cpu().tolist()
+        
+
+        if not args.no_wandb:
+            if batch_idx % log_interval == 0:
+                wandb.log({"train_loss" : loss})
+        optimizer.step()
+
+def visualize_gradients(gradients):
+    print(gradients)
+    import numpy as np
+    X = np.linspace(0,len(gradients),len(gradients))
+    plt.scatter(X,gradients)
+    plt.show()
+
+if __name__=="__main__":
+    import numpy as np
+    Y = [[1,2,3],[2,4,8],[2,5,6], [8,9,10]]
+    X = np.linspace(0,len(Y),len(Y))
+    for i,curve in enumerate(Y):
+        for point in curve : 
+            plt.scatter(X[i],point)
+    plt.show()
\ No newline at end of file