Skip to content
Snippets Groups Projects
Commit 4403597b authored by Yandi's avatar Yandi
Browse files

LSTM with scheduler tested

parent 66dc9e2e
No related branches found
No related tags found
1 merge request!1Master into main
...@@ -43,7 +43,7 @@ Optimizer: Adam # in {Adam} ...@@ -43,7 +43,7 @@ Optimizer: Adam # in {Adam}
#Training parameters #Training parameters
Training: Training:
Epochs: 20 Epochs: 60
#Model selection #Model selection
Model: Model:
......
No preview for this file type
...@@ -114,6 +114,14 @@ if __name__ == "__main__": ...@@ -114,6 +114,14 @@ if __name__ == "__main__":
optimizer = optimizer(cfg, network) optimizer = optimizer(cfg, network)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
'min',
patience = 5,
threshold = 0.2,
factor = 0.5
)
logdir, raw_run_name = utils.create_unique_logpath(rootDir, cfg["Model"]["Name"]) logdir, raw_run_name = utils.create_unique_logpath(rootDir, cfg["Model"]["Name"])
network_checkpoint = model.ModelCheckpoint(logdir + "/best_model.pt", network) network_checkpoint = model.ModelCheckpoint(logdir + "/best_model.pt", network)
...@@ -121,14 +129,16 @@ if __name__ == "__main__": ...@@ -121,14 +129,16 @@ if __name__ == "__main__":
wandb.run.name = raw_run_name wandb.run.name = raw_run_name
wandb.watch(network, log_freq = log_freq) wandb.watch(network, log_freq = log_freq)
torch.autograd.set_detect_anomaly(True) #torch.autograd.set_detect_anomaly(True)
for t in range(cfg["Training"]["Epochs"]): for t in range(cfg["Training"]["Epochs"]):
print("Epoch {}".format(t)) print("Epoch {}".format(t))
train(args, network, train_loader, f_loss, optimizer, device, log_interval) train(args, network, train_loader, f_loss, optimizer, device, log_interval)
val_loss = test.test(network, valid_loader, f_loss, device) val_loss = test.test(network, valid_loader, f_loss, device)
scheduler.step(val_loss)
network_checkpoint.update(val_loss) network_checkpoint.update(val_loss)
print(" Validation : Loss : {:.4f}".format(val_loss)) print(" Validation : Loss : {:.4f}".format(val_loss))
......
...@@ -39,7 +39,6 @@ def train(args, model, loader, f_loss, optimizer, device, log_interval = 100): ...@@ -39,7 +39,6 @@ def train(args, model, loader, f_loss, optimizer, device, log_interval = 100):
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
#torch.nn.utils.clip_grad_norm(model.parameters(), 50) #torch.nn.utils.clip_grad_norm(model.parameters(), 50)
Y = list(model.parameters())[0].grad.cpu().tolist() Y = list(model.parameters())[0].grad.cpu().tolist()
......
No preview for this file type
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment