Skip to content
Snippets Groups Projects
Commit d2721875 authored by Yandi's avatar Yandi
Browse files

[RNN_FFN]

parent d13daed6
No related branches found
No related tags found
1 merge request!1Master into main
......@@ -30,7 +30,7 @@ def dummy_model(X):
# Divided by a magic number
return X[:, :, 4:].mean(dim=2) / 26 # This is (B, T)
def create_submission(model, transform, device, rootDir):
def create_submission(model, transform, device, rootDir, logdir):
step_days = 10
batch_size = 1024
# We make chunks of num_days consecutive samples; As our dummy predictor
......@@ -57,7 +57,7 @@ def create_submission(model, transform, device, rootDir):
num_days_test = test_loader.dataset.ntimes
logging.info("= Filling in the submission file")
with open(rootDir + "submission.csv", "w") as fh_submission:
with open(logdir + "submission.csv", "w") as fh_submission:
fh_submission.write("Id,Predicted\n")
submission_offset = 0
......
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -132,4 +132,4 @@ if __name__ == "__main__":
wandb.log({"val_loss": val_loss})
create_submission.create_submission(network, dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)), device, rootDir)
create_submission.create_submission(network, dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)), device, rootDir, logdir)
......@@ -56,11 +56,11 @@ class RNN(nn.Module):
self.fc.add_module(
f"linear_{layer}", nn.Linear(self.hidden_size, self.hidden_size)
)
self.ffn.add_module(
self.fc.add_module(
f"relu_{layer}",
nn.ReLU()
)
self.ffn.add_module(
self.fc.add_module(
f"dropout_{layer}",
nn.Dropout(p=0.2)
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment