diff --git a/main.py b/main.py
index 5ff24ec5e944e87d2d88c6a0a37e83d671f04297..abb6709a9c95724f1a0943ce256a1408eb16e0ba 100644
--- a/main.py
+++ b/main.py
@@ -16,14 +16,33 @@ import logging
 import torch.optim
 import torch.nn as nn
 import os
+import argparse
 
 
 def optimizer(cfg, network):
     result = {"Adam" : torch.optim.Adam(network.parameters())}
     return result[cfg["Optimizer"]]
 
 
 if __name__ == "__main__":
-    logging.basicConfig(filename='logs/main_unit_test.log', level=logging.INFO)
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--no_wandb",
+        action="store_true",
+        help="If specified, no log will be sent to wandb. Especially useful when running batch jobs.",
+    )
+
+    parser.add_argument(
+        "--rootDir",
+        help="Directory in which the log files will be stored"
+    )
+
+    args = parser.parse_args()
+
     config_file = open("config.yml")
     cfg = yaml.load(config_file)
+
+    # Use the command-line value when given, otherwise fall back to the config entry
+    rootDir = args.rootDir if args.rootDir is not None else cfg["LogDir"]
+
+    logging.basicConfig(filename=os.path.join(rootDir, 'main_unit_test.log'), level=logging.INFO)
+
@@ -39,12 +58,12 @@ if __name__ == "__main__":
     log_freq = int(cfg["Wandb"]["log_freq"])
     log_interval = int(cfg["Wandb"]["log_interval"])
 
+    if not args.no_wandb:
+        wandb.init(entity = "wherephytoplankton", project = "Kaggle phytoplancton", config = {"batch_size": batch_size, "epochs": epochs})
 
-    wandb.init(entity = "wherephytoplankton", project = "Kaggle phytoplancton", config = {"batch_size": batch_size, "epochs": epochs})
-
+
     # Re-compute the statistics or use the stored ones
     approx_stats = cfg["ApproximativeStats"]
-    print(approx_stats)
 
     if approx_stats:
         MEAN = eval(cfg["ApproximativeMean"])
@@ -73,8 +92,6 @@ if __name__ == "__main__":
         valid_ratio,
         overwrite_index = True,
         max_num_samples=max_num_samples,
-        #train_transform = dataloader.transform_remove_space_time(),
-        #valid_transform = dataloader.transform_remove_space_time()
         train_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)),
         valid_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX))
     )
@@ -94,33 +111,24 @@ if __name__ == "__main__":
 
     optimizer = optimizer(cfg, network)
 
-    logdir, raw_run_name = utils.create_unique_logpath(cfg["LogDir"], cfg["Model"]["Name"])
-    wandb.run.name = raw_run_name
+    logdir, raw_run_name = utils.create_unique_logpath(rootDir, cfg["Model"]["Name"])
 
     network_checkpoint = model.ModelCheckpoint(logdir + "/best_model.pt", network)
-
-    wandb.watch(network, log_freq = log_freq)
+    if not args.no_wandb:
+        wandb.run.name = raw_run_name
+        wandb.watch(network, log_freq = log_freq)
 
     for t in range(cfg["Training"]["Epochs"]):
-        torch.autograd.set_detect_anomaly(True)
         print("Epoch {}".format(t))
-        train(network, train_loader, f_loss, optimizer, device, log_interval)
-
+        train(args, network, train_loader, f_loss, optimizer, device, log_interval)
 
-        #print(list(network.parameters())[0].grad)
         val_loss = test.test(network, valid_loader, f_loss, device)
 
         network_checkpoint.update(val_loss)
         print(" Validation : Loss : {:.4f}".format(val_loss))
-        wandb.log({"val_loss": val_loss})
+        if not args.no_wandb:
+            wandb.log({"val_loss": val_loss})
 
     create_submission.create_submission(network, dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)), device)
-
-    """
-    logdir = generate_unique_logpath(top_logdir, "linear")
-    print("Logging to {}".format(logdir))
-    # -> Prints out Logging to ./logs/linear_1
-    if not os.path.exists(logdir):
-        os.mkdir(logdir)
-    """
\ No newline at end of file
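
Note: with these changes, wandb logging is on by default and skipped when --no_wandb is passed, and the log directory is taken from --rootDir when given, otherwise from LogDir in config.yml. A possible invocation (the directory name here is only an example, not a path mandated by the code):

    python main.py --rootDir ./logs/ --no_wandb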