diff --git a/main.py b/main.py index 76645eb3625939e1d841fe78ea9f6dd200eb754d..d0d75d933c6822113fe926713583097fa8770a71 100644 --- a/main.py +++ b/main.py @@ -121,6 +121,7 @@ def train(args, cfg): if best_val_loss != None: if val_loss < best_val_loss : network_checkpoint.update(val_loss) + best_val_loss = val_loss scheduler.step(val_loss) @@ -207,6 +208,6 @@ if __name__ == "__main__": config_file = open("config.yml") cfg = yaml.load(config_file, Loader=yaml.FullLoader) - + eval(f"{args.command}(args)")