Commit 4403597b authored by Yandi

LSTM with scheduler tested

parent 66dc9e2e
1 merge request: !1 Master into main
@@ -43,7 +43,7 @@ Optimizer: Adam # in {Adam}
 #Training parameters
 Training:
-  Epochs: 20
+  Epochs: 60
 #Model selection
 Model:
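For context, this YAML is the configuration dictionary (cfg) used by the training script below. A minimal sketch of how such a config can be loaded with PyYAML; the inline document and the model name "LSTM" are assumptions, only the Training/Epochs and Model/Name keys appear in the diff:

import yaml

# Stand-in for the project's config file, loaded into the cfg dict.
cfg = yaml.safe_load("""
Optimizer: Adam
Training:
  Epochs: 60
Model:
  Name: LSTM
""")

print(cfg["Training"]["Epochs"])   # 60 after this commit (was 20)
print(cfg["Model"]["Name"])        # used to name the log directory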
@@ -114,6 +114,14 @@ if __name__ == "__main__":
     optimizer = optimizer(cfg, network)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer,
+        'min',
+        patience = 5,
+        threshold = 0.2,
+        factor = 0.5
+    )
     logdir, raw_run_name = utils.create_unique_logpath(rootDir, cfg["Model"]["Name"])
     network_checkpoint = model.ModelCheckpoint(logdir + "/best_model.pt", network)
@@ -121,14 +129,16 @@ if __name__ == "__main__":
     wandb.run.name = raw_run_name
     wandb.watch(network, log_freq = log_freq)
-    torch.autograd.set_detect_anomaly(True)
+    #torch.autograd.set_detect_anomaly(True)
     for t in range(cfg["Training"]["Epochs"]):
         print("Epoch {}".format(t))
         train(args, network, train_loader, f_loss, optimizer, device, log_interval)
         val_loss = test.test(network, valid_loader, f_loss, device)
+        scheduler.step(val_loss)
         network_checkpoint.update(val_loss)
         print(" Validation : Loss : {:.4f}".format(val_loss))
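To see the new scheduler in isolation: a minimal, runnable sketch of the ReduceLROnPlateau wiring added above. Only the scheduler construction and the scheduler.step(val_loss) call mirror the diff; the model, the placeholder validation loss, and the checkpointing logic are stand-ins for the project's train/test helpers and ModelCheckpoint class.

import torch

network = torch.nn.Linear(10, 1)                      # stand-in model
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    'min',           # watch a metric that should decrease (validation loss)
    patience=5,      # wait 5 epochs without improvement before acting
    threshold=0.2,   # minimum change that counts as an improvement
    factor=0.5,      # halve the learning rate when triggered
)

best_loss = float("inf")
for epoch in range(60):
    # train(...) would run one epoch here; we fake a decreasing validation loss.
    val_loss = 1.0 / (epoch + 1)
    scheduler.step(val_loss)                 # the scheduler sees the metric, not the epoch
    if val_loss < best_loss:                 # same role as ModelCheckpoint.update
        best_loss = val_loss
        torch.save(network.state_dict(), "best_model.pt")
    print(f"Epoch {epoch}: val_loss={val_loss:.4f}, lr={optimizer.param_groups[0]['lr']:.2e}")

Note that ReduceLROnPlateau, unlike most PyTorch schedulers, must be stepped with the monitored metric, which is why the diff adds the call right after the validation pass.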
@@ -39,7 +39,6 @@ def train(args, model, loader, f_loss, optimizer, device, log_interval = 100):
         optimizer.zero_grad()
         loss.backward()
         #torch.nn.utils.clip_grad_norm(model.parameters(), 50)
         Y = list(model.parameters())[0].grad.cpu().tolist()
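The hunk above keeps a commented-out clip_grad_norm call and inspects the raw gradient of the first parameter. A hedged sketch of both ideas using the current, non-deprecated clip_grad_norm_ API; the max_norm of 50 is taken from the comment, while the toy model and data are assumptions:

import torch

model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()

# Clip the global gradient norm in place; the pre-clipping norm is returned.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=50)

# Equivalent of the diff's inspection of the first parameter's gradient.
first_grad = list(model.parameters())[0].grad.cpu().tolist()
print(f"grad norm before clipping: {float(total_norm):.4f}")
print(first_grad)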