Skip to content
Snippets Groups Projects
Commit 4d54dce2 authored by Yandi's avatar Yandi
Browse files

cleaning codebase by moving create_submission to utils

parent f1369e13
Branches
No related tags found
No related merge requests found
# Dataset Configuration
Dataset:
num_days: 73 # Number of days in each sequence - should be the same as in Test -
batch_size: 64
batch_size: 128
num_workers: 7
valid_ratio: 0.2
max_num_samples: None #1000
......@@ -61,7 +61,7 @@ BidirectionalLSTM:
HiddenSize: 16
NumLayers: 2
LSTMDropout: 0
FFNDropout: 0.2
FFNDropout: 0
NumFFN: 3
Initialization: init_he
......
No preview for this file type
#Internal imports
import dataloader
import model
import test
import my_test
import my_train
import create_submission
import utils
#External imports
......@@ -16,7 +15,7 @@ import torch.nn as nn
import os
import argparse
def optimizer(cfg, network):
def choose_optimizer(cfg, network):
result = {"Adam" : torch.optim.Adam(network.parameters())}
return result[cfg["Optimizer"]]
......@@ -91,7 +90,7 @@ def train(args, cfg):
f_loss = model.RMSLELoss()
optimizer = optimizer(cfg, network)
optimizer = choose_optimizer(cfg, network)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
......@@ -113,10 +112,10 @@ def train(args, cfg):
best_val_loss = None
for t in range(cfg["Training"]["Epochs"]):
print("Epoch {}".format(t))
print(f"Epoch {t+1}")
my_train.train(args, network, train_loader, f_loss, optimizer, device, log_interval)
val_loss = test.test(network, valid_loader, f_loss, device)
val_loss = my_test.test(network, valid_loader, f_loss, device)
if best_val_loss != None:
if val_loss < best_val_loss :
......@@ -126,7 +125,7 @@ def train(args, cfg):
scheduler.step(val_loss)
print(" Validation : Loss : {:.4f}".format(val_loss))
print("Validation : Loss : {:.4f}".format(val_loss))
if not args.no_wandb:
wandb.log({"val_loss": val_loss})
......@@ -155,7 +154,7 @@ def test(args):
network.load_state_dict(torch.load(model_path))
create_submission.create_submission(args, network, eval(dataset_transform), device, rootDir, logdir)
utils.create_submission(args, network, eval(dataset_transform), device, rootDir, logdir)
logging.info(f"The submission csv file has been created in the folder : {logdir}")
......@@ -209,5 +208,5 @@ if __name__ == "__main__":
config_file = open("config.yml")
cfg = yaml.load(config_file, Loader=yaml.FullLoader)
eval(f"{args.command}(args)")
eval(f"{args.command}(args, cfg)")
......@@ -80,7 +80,14 @@ class BidirectionalLSTM(nn.Module):
self.FFN_dropout = cfg["BidirectionalLSTM"]["FFNDropout"]
self.num_ffn = cfg["BidirectionalLSTM"]["NumFFN"]
self.lstm = nn.LSTM(input_size, self.hidden_size, self.num_layers, batch_first = True, bidirectional =True, dropout = self.LSTM_dropout)
self.lstm = nn.LSTM(
input_size,
self.hidden_size,
self.num_layers,
batch_first = True,
bidirectional =True,
dropout = self.LSTM_dropout)
self.fc = nn.Sequential()
for layer in range(self.num_ffn):
......@@ -105,7 +112,8 @@ class BidirectionalLSTM(nn.Module):
if use_cuda :
device = torch.device('cuda')
else :
device = toch.device('cpu')
device = torch.device('cpu')
h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
......@@ -114,6 +122,7 @@ class BidirectionalLSTM(nn.Module):
result = self.fc(out)
result = nn.ReLU()(result)
return result
......@@ -122,7 +131,7 @@ class CNN1D(torch.nn.Module):
def __init__(self, cfg, num_inputs):
super(CNN1D, self).__init__()
self.model = torch.nn.Sequential(
self.block = torch.nn.Sequential(
*conv_block(num_inputs, 32),
*conv_block(32, 128)
)
......@@ -137,9 +146,8 @@ class CNN1D(torch.nn.Module):
def forward(self, x):
x = torch.transpose(x, 1, 2)
out = self.model(x)
out = self.block(x)
print(f"This is after CNN : {out}")
out = self.avg_pool(out)
out = out.view([out.shape[0], -1])
......
File moved
No preview for this file type
import os
# Standard imports
import sys
import logging
import os
import datetime
# External imports
import tqdm
import torch
import torch.nn as nn
import argparse
import yaml
def generate_unique_logpath(logdir, raw_run_name):
i = 0
......@@ -46,4 +56,6 @@ def write_summary(logdir, model, optimizer, val_loss):
""".format(val_loss," ".join(sys.argv), model, sum(p.numel() for p in model.parameters() if p.requires_grad), optimizer)
summary_file.write(summary_text)
summary_file.close()
\ No newline at end of file
summary_file.close()
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment