diff --git a/logs/main_unit_test.log b/logs/main_unit_test.log
index deee767d885c526c245ad202326f1777abe2581c..6feec89f290557e58d7a58b53b59f22c91620b18 100644
--- a/logs/main_unit_test.log
+++ b/logs/main_unit_test.log
@@ -1607,3 +1607,18 @@ INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and
 INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
 INFO:root: - The train fold has 542071 samples
 INFO:root: - The valid fold has 135089 samples
+INFO:root:= Dataloaders
+INFO:root: - Dataset creation
+INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 2222 time points
+INFO:root:Generating the index
+INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
+INFO:root: - Loaded a dataset with 677160 samples
+INFO:root: - Splitting the data in training and validation sets
+INFO:root:Generating the subset files from 677160 samples
+INFO:root: - Subset dataset
+INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 2222 time points
+INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
+INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 2222 time points
+INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
+INFO:root: - The train fold has 541670 samples
+INFO:root: - The valid fold has 135490 samples
diff --git a/main.py b/main.py
index b511a718e4f609d02444ceacb5302e5470c094ae..e1089f91e017886fdd5ff0bbf2818037e24597d4 100644
--- a/main.py
+++ b/main.py
@@ -120,16 +120,17 @@ if __name__ == "__main__":
         wandb.watch(network, log_freq = log_freq)
 
     for t in range(cfg["Training"]["Epochs"]):
-        print("Epoch {}".format(t))
+        logging.info("Epoch {}".format(t))
         train(args, network, train_loader, f_loss, optimizer, device, log_interval)
 
         val_loss = test.test(network, valid_loader, f_loss, device)
         network_checkpoint.update(val_loss)
-        print(" Validation : Loss : {:.4f}".format(val_loss))
+        logging.info(" Validation : Loss : {:.4f}".format(val_loss))
         if not args.no_wandb:
             wandb.log({"val_loss": val_loss})
 
-
+    utils.write_summary(logdir, network, optimizer, val_loss)
+
     create_submission.create_submission(network, dataloader.composite_transform(dataloader.transform_remove_space_time(),
                                         dataloader.transform_min_max_scaling(MIN, MAX)),
                                         device, rootDir, logdir)
diff --git a/train.py b/train.py
index 335fb5bd325cfcef3094b1aa6c663319a679c43a..ed49b27b5a363e847b535b39f29c43c8900eb0a3 100644
--- a/train.py
+++ b/train.py
@@ -42,18 +42,12 @@ def train(args, model, loader, f_loss, optimizer, device, log_interval = 100):
         Y = list(model.parameters())[0].grad.cpu().tolist()
 
-        #gradients.append(np.mean(Y))
-        #tar.append(np.mean(outputs.cpu().tolist()))
-        #out.append(np.mean(targets.cpu().tolist()))
+
         if not args.no_wandb:
             if batch_idx % log_interval == 0:
                 wandb.log({"train_loss" : loss})
 
         optimizer.step()
 
-    #visualize_gradients(gradients)
-    #visualize_gradients(tar)
-    #visualize_gradients(out)
-
 def visualize_gradients(gradients):
     print(gradients)
     import numpy as np
diff --git a/utils.py b/utils.py
index e5b78aeb31791d143e46b8b7473a63b40cbf0751..38b9089d4b31f8be656ad2bf4f28d94e2fe13428 100644
--- a/utils.py
+++ b/utils.py
@@ -1,4 +1,5 @@
 import os
+import sys
 
 def generate_unique_logpath(logdir, raw_run_name):
     i = 0
@@ -14,4 +15,35 @@ def create_unique_logpath(top_logdir, raw_run_name):
     os.mkdir(top_logdir)
     logdir = generate_unique_logpath(top_logdir, raw_run_name)
     os.mkdir(logdir)
-    return logdir, raw_run_name
\ No newline at end of file
+    return logdir, raw_run_name
+
+def write_summary(logdir, model, optimizer, val_loss):
+    summary_file = open(logdir + "/summary.txt", 'w')
+    summary_text = """
+Validation loss
+===============
+{}
+
+
+Executed command
+================
+{}
+
+Dataset
+=======
+FashionMNIST
+
+Model summary
+=============
+{}
+
+{} trainable parameters
+
+Optimizer
+========
+{}
+
+
+""".format(val_loss," ".join(sys.argv), model, sum(p.numel() for p in model.parameters() if p.requires_grad), optimizer)
+    summary_file.write(summary_text)
+    summary_file.close()
\ No newline at end of file
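
Note on the print-to-logging switch in main.py: logging.info only emits records when the root logger is configured at INFO level or below, since the root logger defaults to WARNING. The INFO:root: prefix seen in the log above is the logging module's default record format, which suggests a basicConfig call exists elsewhere in the project; the exact call site is not shown in this diff. A minimal sketch of the configuration the patched code assumes:

import logging

# The root logger defaults to WARNING, so without this (or an
# equivalent handler setup) the logging.info() records are dropped.
logging.basicConfig(level=logging.INFO)

logging.info("Epoch {}".format(0))  # emits: INFO:root:Epoch 0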
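
The new utils.write_summary helper runs as written, but it builds the output path by string concatenation, relies on an explicit close() call, and hardcodes FashionMNIST as the dataset name even though the log above refers to a CMEMS MEDSEA NetCDF dataset. A sketch of a more defensive variant follows; the dataset_name parameter is an assumption for illustration, not part of the patch:

import os
import sys

def write_summary(logdir, model, optimizer, val_loss, dataset_name="unknown"):
    # Count only the parameters the optimizer will actually update.
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    summary_text = (
        "Validation loss\n"
        "===============\n"
        "{}\n\n"
        "Executed command\n"
        "================\n"
        "{}\n\n"
        "Dataset\n"
        "=======\n"
        "{}\n\n"
        "Model summary\n"
        "=============\n"
        "{}\n\n"
        "{} trainable parameters\n\n"
        "Optimizer\n"
        "=========\n"
        "{}\n"
    ).format(val_loss, " ".join(sys.argv), dataset_name, model, n_trainable, optimizer)
    # os.path.join avoids doubled or missing path separators, and the
    # context manager guarantees the file is flushed and closed even
    # if write() raises.
    with open(os.path.join(logdir, "summary.txt"), "w") as summary_file:
        summary_file.write(summary_text)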