from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb

def train(args, model, loader, f_loss, optimizer, device, log_interval = 100):
    """
    Train a model for one epoch, iterating over the loader
    using the f_loss to compute the loss and the optimizer
    to update the parameters of the model.

    Arguments :

        args         -- namespace carrying a boolean `no_wandb` attribute
        model        -- A torch.nn.Module object
        loader       -- A torch.utils.data.DataLoader
        f_loss       -- The loss function, i.e. a loss Module
        optimizer    -- A torch.optim.Optimizer object
        device       -- a torch.device class specifying the device
                        used for computation
        log_interval -- log the training loss to wandb every
                        `log_interval` batches (when wandb is enabled)

    Returns :

        The average per-batch training loss over the epoch.
    """

    model.train()

    total_loss = 0.0
    num_batches = 0
    for batch_idx, (inputs, targets) in tqdm(enumerate(loader), total = len(loader)):
        inputs, targets = inputs.to(device), targets.to(device)

        # Compute the forward pass through the network up to the loss
        outputs = model(inputs)
        loss = f_loss(outputs, targets)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()

        if not args.no_wandb:
            if batch_idx % log_interval == 0:
                # .item() extracts a detached Python float; logging the raw
                # tensor would keep the whole autograd graph alive.
                wandb.log({"train_loss" : loss.item()})

        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # max(..., 1) guards against an empty loader (zero batches).
    return total_loss / max(num_batches, 1)

def visualize_gradients(gradients):
    """
    Scatter-plot a sequence of scalar values against their index.

    Arguments :

        gradients -- a sequence of scalars (e.g. per-batch mean gradients)
    """
    print(gradients)
    # np is already imported at module level; the original re-imported it here.
    # linspace(0, n, n) reproduces the original x-axis spacing exactly.
    X = np.linspace(0, len(gradients), len(gradients))
    plt.scatter(X, gradients)
    plt.show()  # blocks until the figure window is closed
if __name__=="__main__":
    # Quick manual demo: scatter-plot a few small curves, one x position
    # per curve. np is already imported at module level.
    Y = [[1, 2, 3], [2, 4, 8], [2, 5, 6], [8, 9, 10]]
    X = np.linspace(0, len(Y), len(Y))
    for i, curve in enumerate(Y):
        for point in curve:
            plt.scatter(X[i], point)
    plt.show()