diff --git a/config.yml b/config.yml
index ab9b5b1fddeadf52594ae47c515c81a6bca0eb24..63ca7a391b7e2b5c90f9a3aa59d0ac4951835b2d 100644
--- a/config.yml
+++ b/config.yml
@@ -21,6 +21,14 @@ ApproximativeMean: "torch.tensor([ 4.2457e+01, 7.4651e+00, 1.6738e+02, 1.3576
 ApproximativeSTD: "torch.tensor([5.8939e-01, 8.1625e-01, 1.4535e+02, 5.4952e+07, 1.7543e-02, 1.3846e+02, 2.1302e-01, 1.9558e+00, 4.1455e+00, 1.2408e+01, 2.2938e-02, 9.9070e-02, 1.9490e-01, 9.2847e-03, 2.2575e+00, 8.5310e-02, 7.8280e-02, 8.6237e-02])"
+ApproximativeMaxi: "torch.tensor([ 4.3479e+01, 9.0000e+00, 4.9267e+02, 1.4528e+09, 2.4088e+00,
+                                   2.7824e+03, 1.5576e+00, 6.2457e+00, 2.5120e+02, 2.7188e+02,
+                                   8.1683e+00, 3.2447e-01, 3.9041e+01, 2.7162e+00, 2.9419e+01,
+                                   8.6284e-01, 7.6471e-01, -7.7745e-02])"
+ApproximativeMini: "torch.tensor([ 4.1479e+01, 6.0000e+00, 1.0182e+00, 1.2623e+09, 2.2433e+00,
+                                   1.0910e+01, 1.0000e-11, 1.0000e-11, -1.1467e+01, 1.9718e+02,
+                                   7.9218e+00, 1.0000e-11, 3.7171e+01, 2.5584e+00, 1.2075e+01,
+                                   -1.2436e+00, -9.9256e-01, -8.8131e-01])"
 
 #Optimizer selection
 Optimizer: Adam # in {Adam}
diff --git a/dataloader.py b/dataloader.py
index a781edd55ed37b0410332611e3cce3f0ce623f9c..d312555015cb363fa1f7b0f04e41d6652cc615d7 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -201,6 +201,10 @@ def get_stats_train_dataset(
     mean = 0.
     std = 0.
     nb_samples = 0.
+    it = iter(train_loader)
+    X, Y = next(it)
+    maxi = X.max(dim=0).values.max(dim=0).values
+    mini = X.min(dim=0).values.min(dim=0).values
 
     print("Computing the statistics of the train dataset")
 
@@ -209,6 +213,9 @@
         # Update the total sum
         mean += data[0].mean(dim=0).mean(dim=0) * data[0].size(0)
         std += data[0].std(dim=0).mean(dim=0) * data[0].size(0)
+        # Update mini and maxi
+        maxi = torch.stack((data[0].max(dim=0).values.max(dim=0).values, maxi)).max(dim=0).values
+        mini = torch.stack((data[0].min(dim=0).values.min(dim=0).values, mini)).min(dim=0).values
 
         # Update the number of samples
         nb_samples += data[0].size(0)
@@ -216,7 +223,7 @@
     mean /= nb_samples
     std /= nb_samples
 
-    return mean, std
+    return mean, std, maxi, mini
 
 def get_test_dataloader(
     filepath,
@@ -259,6 +266,9 @@ def transform_remove_space_time():
 def transform_normalize_with_train_statistics(MEAN, STD):
     return lambda attributes : torch.div(torch.sub(attributes, MEAN), STD)
 
+def transform_min_max_scaling(MIN, MAX):
+    return lambda attributes : torch.div(torch.sub(attributes, MIN), torch.sub(MAX, MIN))
+
 def composite_transform(f,g):
     return lambda attributes : f(g(attributes))
 
@@ -292,8 +302,8 @@ if __name__ == "__main__":
         train_transform =composite_transform(transform_remove_space_time(), transform_normalize_with_train_statistics(MEAN, STD))
     )
 
-    """
-    #mean, std = get_stats_train_dataset(
+
+    mean, std, maxi, mini = get_stats_train_dataset(
         filepath = trainpath,
         num_days = num_days,
         batch_size = batch_size,
@@ -303,9 +313,12 @@
         overwrite_index=True,
         max_num_samples=max_num_samples,
     )
-    """
-
+
+    print("Mini =")
+    print(mini)
+    print("Maxi =")
+    print(maxi)
 
     it = iter(train_loader)
     X, Y = next(it)
diff --git a/logs/LinearRegression_10_Kaggle_2/best_model.pt b/logs/LinearRegression_10_Kaggle_2/best_model.pt
deleted file mode 100644
index 582f03fe3775a85a568499c448a18ac353cc8941..0000000000000000000000000000000000000000
Binary files a/logs/LinearRegression_10_Kaggle_2/best_model.pt and /dev/null differ
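Note on the statistics pass added to `get_stats_train_dataset` above: the diff seeds `maxi`/`mini` from one batch pulled off a fresh iterator, then folds each subsequent batch in by stacking it against the running extremum and reducing again. Below is a minimal sketch of an equivalent fold using `torch.maximum`/`torch.minimum`; it is an illustration, not the diff's code, and it assumes each batch `X` is shaped `(batch, seq_len, n_features)`, so that reducing over the first two dimensions leaves one value per feature.

```python
import torch

def feature_min_max(train_loader):
    """Sketch: running per-feature min/max over a DataLoader.

    Assumes batches X of shape (batch, seq_len, n_features); the loader
    name and shapes are assumptions taken from context, not the diff.
    """
    maxi, mini = None, None
    for X, _ in train_loader:
        batch_max = X.amax(dim=(0, 1))   # per-feature max over batch and time
        batch_min = X.amin(dim=(0, 1))
        if maxi is None:
            # Seed from the first batch, as the diff does, instead of +/-inf.
            maxi, mini = batch_max, batch_min
        else:
            # Elementwise fold; equivalent to torch.stack(...).max(dim=0).values
            maxi = torch.maximum(maxi, batch_max)
            mini = torch.minimum(mini, batch_min)
    return maxi, mini
```

One caveat in the diff itself: the extra `next(iter(train_loader))` draws a full batch purely for initialisation, which is correct (the loop re-reads it) but costs one batch of I/O per call.
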
diff --git a/main.py b/main.py
index 7864a331f8f0e0b81fecb17cce978aa61ea38189..5ff24ec5e944e87d2d88c6a0a37e83d671f04297 100644
--- a/main.py
+++ b/main.py
@@ -49,8 +49,10 @@ if __name__ == "__main__":
     if approx_stats:
         MEAN = eval(cfg["ApproximativeMean"])
         STD = eval(cfg["ApproximativeSTD"])
+        MAX = eval(cfg["ApproximativeMaxi"])
+        MIN = eval(cfg["ApproximativeMini"])
     else :
-        MEAN, STD = get_stats_train_dataset(trainpath,
+        MEAN, STD, MAX, MIN = get_stats_train_dataset(trainpath,
             num_days,
             batch_size,
             num_workers,
@@ -73,8 +75,8 @@
         max_num_samples=max_num_samples,
         #train_transform = dataloader.transform_remove_space_time(),
         #valid_transform = dataloader.transform_remove_space_time()
-        train_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_normalize_with_train_statistics(MEAN, STD)),
-        valid_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_normalize_with_train_statistics(MEAN, STD))
+        train_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)),
+        valid_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX))
     )
 
     if use_cuda :
@@ -92,8 +94,8 @@
     optimizer = optimizer(cfg, network)
 
-    logdir = utils.create_unique_logpath(cfg["LogDir"], cfg["Model"]["Name"])
-    wandb.run.name = utils.generate_unique_logpath("", cfg["Model"]["Name"])
+    logdir, raw_run_name = utils.create_unique_logpath(cfg["LogDir"], cfg["Model"]["Name"])
+    wandb.run.name = raw_run_name
 
     network_checkpoint = model.ModelCheckpoint(logdir + "/best_model.pt", network)
 
@@ -114,7 +116,7 @@
     wandb.log({"val_loss": val_loss})
 
-    create_submission.create_submission(network, dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_normalize_with_train_statistics(MEAN, STD)), device)
+    create_submission.create_submission(network, dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)), device)
 
     """
     logdir = generate_unique_logpath(top_logdir, "linear")
     print("Logging to {}".format(logdir))
diff --git a/utils.py b/utils.py
index 0c585e715e1357f627c86f3a759f9c131fa34768..e5b78aeb31791d143e46b8b7473a63b40cbf0751 100644
--- a/utils.py
+++ b/utils.py
@@ -14,4 +14,4 @@ def create_unique_logpath(top_logdir, raw_run_name):
         os.mkdir(top_logdir)
     logdir = generate_unique_logpath(top_logdir, raw_run_name)
     os.mkdir(logdir)
-    return logdir
\ No newline at end of file
+    return logdir, raw_run_name
\ No newline at end of file
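The new `transform_min_max_scaling(MIN, MAX)` rescales each attribute to roughly [0, 1] via (x - MIN) / (MAX - MIN). With the config's `ApproximativeMini`/`ApproximativeMaxi` tensors this is safe, but any feature with MAX == MIN would divide by zero. A hypothetical guarded variant is sketched below; `transform_min_max_scaling_safe` and its `eps` floor are illustration only, not part of the diff.

```python
import torch

def transform_min_max_scaling_safe(MIN, MAX, eps=1e-12):
    # Clamping the per-feature range avoids division by zero (NaN/inf
    # outputs) for constant features; eps is an assumed small floor.
    span = torch.clamp(torch.sub(MAX, MIN), min=eps)
    return lambda attributes: torch.div(torch.sub(attributes, MIN), span)

# Usage mirroring main.py (names taken from the diff):
# train_transform = dataloader.composite_transform(
#     dataloader.transform_remove_space_time(),
#     transform_min_max_scaling_safe(MIN, MAX),
# )
```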