# main.py
# Internal imports
import dataloader
import model
import test
from train import train
import losses
import optimizers
import create_submission
import utils
# External imports
import yaml
import torch
import logging
import torch.optim
import torch.nn as nn
import os
def get_optimizer(cfg, network):
    # Build the optimizer requested in the config; only Adam is wired up for now
    available = {"Adam": torch.optim.Adam}
    return available[cfg["Optimizer"]](network.parameters())
if __name__ == "__main__":
    # Make sure the log directory exists before configuring logging
    os.makedirs("logs", exist_ok=True)
    logging.basicConfig(filename='logs/main_unit_test.log', level=logging.INFO)

    # Load the experiment configuration
    with open("config.yml") as config_file:
        cfg = yaml.safe_load(config_file)
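    # For reference, config.yml is expected to provide roughly these keys
    # (structure inferred from the accesses in this script):
    #   Dataset: {_DEFAULT_TRAIN_FILEPATH, num_days, batch_size, num_workers, valid_ratio, max_num_samples}
    #   Optimizer: "Adam" (the only option wired up in get_optimizer)
    #   LogDir, Model: {Name}, Training: {Epochs}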
use_cuda = torch.cuda.is_available()
trainpath = cfg["Dataset"]["_DEFAULT_TRAIN_FILEPATH"]
num_days = int(cfg["Dataset"]["num_days"])
batch_size = int(cfg["Dataset"]["batch_size"])
num_workers = int(cfg["Dataset"]["num_workers"])
valid_ratio = float(cfg["Dataset"]["valid_ratio"])
    # max_num_samples is stored as a string in the config (e.g. "None"), hence the eval
    max_num_samples = eval(cfg["Dataset"]["max_num_samples"])
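    # Build the training/validation dataloaders from the raw training file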
train_loader, valid_loader = dataloader.get_dataloaders(
trainpath,
num_days,
batch_size,
num_workers,
use_cuda,
valid_ratio,
overwrite_index=True,
max_num_samples=max_num_samples,
train_transform=dataloader.transform_remove_space_time(),
valid_transform=dataloader.transform_remove_space_time()
)
    if use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
#network = network.build_network(cfg, 18)
    # Simple fully connected regressor: 14 input features -> 1 output
    network = nn.Sequential(
        nn.Linear(14, 8, bias=False),
        nn.ReLU(),
        nn.Linear(8, 35, bias=True),
        nn.ReLU(),
        nn.Linear(35, 35, bias=True),
        nn.ReLU(),
        nn.Linear(35, 35, bias=True),
        nn.ReLU(),
        nn.Linear(35, 35, bias=True),
        nn.ReLU(),
        nn.Linear(35, 35, bias=True),
        nn.ReLU(),
        nn.Linear(35, 1, bias=True),
        nn.ReLU()
    )
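    # The trailing ReLU clamps predictions to be non-negative, which keeps the log1p inside the RMSLE loss well defined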
    def init_xavier(module):
        # Xavier-initialise the weights of every Linear layer
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
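    # network.apply(init_xavier)  # uncomment to actually apply the Xavier initialisation defined above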
network = network.to(device)
"""
for param in list(network.parameters()):
param = 1
"""
f_loss = losses.RMSLELoss()
    optimizer = get_optimizer(cfg, network)
logdir = utils.create_unique_logpath(cfg["LogDir"], cfg["Model"]["Name"])
network_checkpoint = model.ModelCheckpoint(logdir + "/best_model.pt", network)
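    # Training loop: one full pass over the training data per epoch, with validation and checkpointing after each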
    # Detect NaN/Inf in the backward pass; useful for debugging but slows training
    torch.autograd.set_detect_anomaly(True)
    for t in range(cfg["Training"]["Epochs"]):
        print("Epoch {}".format(t))
        train(network, train_loader, f_loss, optimizer, device)
        #print(list(network.parameters())[0].grad)
        val_loss = test.test(network, valid_loader, f_loss, device)
        network_checkpoint.update(val_loss)
        print(" Validation : Loss : {:.4f}".format(val_loss))
create_submission.create_submission(network, None)
"""
logdir = generate_unique_logpath(top_logdir, "linear")
print("Logging to {}".format(logdir))
# -> Prints out Logging to ./logs/linear_1
if not os.path.exists(logdir):
os.mkdir(logdir)
"""