From aaf309df85ca565e71bfe7f4ac5108c3f6cdde5c Mon Sep 17 00:00:00 2001 From: Yandi <yandirzm@gmail.com> Date: Sun, 22 Jan 2023 13:02:05 +0100 Subject: [PATCH] [Dev] Added transform argument --- create_submission.py | 2 +- main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/create_submission.py b/create_submission.py index 86f6bd8..45aebaa 100644 --- a/create_submission.py +++ b/create_submission.py @@ -23,7 +23,7 @@ import torch # Local imports import bindataset as dataset -def create_submission(model): +def create_submission(model, transform): step_days = 10 batch_size = 1024 # We make chunks of num_days consecutive samples; As our dummy predictor diff --git a/main.py b/main.py index 1d844c5..a29d8dc 100644 --- a/main.py +++ b/main.py @@ -79,7 +79,7 @@ if __name__ == "__main__": val_loss = test.test(model, valid_loader, f_loss, device) print(" Validation : Loss : {:.4f}".format(val_loss)) - create_submission.create_submission(model) + create_submission.create_submission(model, None) """ logdir = generate_unique_logpath(top_logdir, "linear") print("Logging to {}".format(logdir)) -- GitLab