import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from tqdm import tqdm
def train(args, model, loader, f_loss, optimizer, device, log_interval=100):
    """
    Train a model for one epoch, iterating over the loader,
    using f_loss to compute the loss and the optimizer
    to update the parameters of the model.

    Arguments:
        args         -- The parsed command-line arguments; args.no_wandb
                        disables wandb logging
        model        -- A torch.nn.Module object
        loader       -- A torch.utils.data.DataLoader
        f_loss       -- The loss function, i.e. a loss Module
        optimizer    -- A torch.optim.Optimizer object
        device       -- A torch.device specifying the device used
                        for computation
        log_interval -- Number of batches between two calls to wandb.log

    Returns:
        None
    """
    model.train()
    for batch_idx, (inputs, targets) in tqdm(enumerate(loader), total=len(loader)):
        inputs, targets = inputs.to(device), targets.to(device)

        # Compute the forward pass through the network up to the loss
        outputs = model(inputs)
        loss = f_loss(outputs, targets)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        # Optionally clip the gradients to stabilize training
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 50)
        optimizer.step()

        if not args.no_wandb and batch_idx % log_interval == 0:
            wandb.log({"train_loss": loss.item()})
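
As a sanity check, here is a minimal sketch of how train might be wired up end to end. The linear model, random tensors, and hyperparameters below are placeholders for illustration, not the actual experimental setup; wandb logging is disabled so the sketch runs without a wandb.init call.

import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Placeholder arguments; no_wandb=True so the sketch does not
# require a configured wandb run
args = argparse.Namespace(no_wandb=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(10, 2).to(device)  # placeholder model
dataset = TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,)))
loader = DataLoader(dataset, batch_size=32, shuffle=True)
f_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train(args, model, loader, f_loss, optimizer, device)
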
def visualize_gradients(gradients):
    """
    Scatter plot a sequence of scalar values (e.g. per-batch
    gradient norms) against their batch index.
    """
    X = np.arange(len(gradients))
    plt.scatter(X, gradients)
    plt.xlabel("Batch index")
    plt.show()
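
One way to use visualize_gradients is to record the global gradient norm after each backward pass and plot it. A minimal sketch with a toy CPU-only model follows; clip_grad_norm_ called with max_norm=inf returns the total gradient norm without actually clipping anything, so it works as a pure measurement. All names and values here are hypothetical placeholders, not the original experiment.

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # toy placeholder model
f_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

grad_norms = []
for _ in range(50):
    inputs = torch.randn(32, 10)
    targets = torch.randint(0, 2, (32,))
    loss = f_loss(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    # With max_norm=inf, clip_grad_norm_ only measures the total norm
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf"))
    grad_norms.append(total_norm.item())
    optimizer.step()

visualize_gradients(grad_norms)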