from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb

def train(args, model, loader, f_loss, optimizer, device, log_interval = 100):
    """
    Train a model for one epoch, iterating over the loader
    using the f_loss to compute the loss and the optimizer
    to update the parameters of the model.

    Arguments :

        args         -- An object exposing a `no_wandb` attribute; when
                        it is truthy nothing is sent to wandb
        model        -- A torch.nn.Module object
        loader       -- A torch.utils.data.DataLoader
        f_loss       -- The loss function, i.e. a loss Module
        optimizer    -- A torch.optim.Optimizer object
        device       -- a torch.device class specifying the device
                        used for computation
        log_interval -- The training loss is logged to wandb on every
                        batch whose index is a multiple of this value

    Returns :

        None. The parameters of the model are updated in place.
    """

    model.train()

    for batch_idx, (inputs, targets) in tqdm(enumerate(loader), total = len(loader)):
        inputs, targets = inputs.to(device), targets.to(device)

        # Compute the forward pass through the network up to the loss
        outputs = model(inputs)
        loss = f_loss(outputs, targets)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if not args.no_wandb and batch_idx % log_interval == 0:
            # Log a plain Python float rather than the loss tensor, so the
            # logger never holds a reference to the autograd graph.
            wandb.log({"train_loss" : loss.item()})

def visualize_gradients(gradients):
    """
    Print a sequence of values and display it as a scatter plot.

    Arguments :

        gradients -- A sized sequence of values, used as the y
                     coordinates of the scatter plot

    The x coordinates are len(gradients) evenly spaced values going
    from 0 to len(gradients), both ends included. The figure is
    displayed with plt.show().
    """
    print(gradients)
    count = len(gradients)
    positions = np.linspace(0, count, count)
    plt.scatter(positions, gradients)
    plt.show()

if __name__=="__main__":
    # Manual demo: every inner list is one "curve" whose values are all
    # drawn at the same x position, one scatter call per value.
    curves = [[1,2,3],[2,4,8],[2,5,6], [8,9,10]]
    abscissas = np.linspace(0, len(curves), len(curves))
    for x_pos, curve in zip(abscissas, curves):
        for value in curve:
            plt.scatter(x_pos, value)
    plt.show()