# main.py
# Entry-point script: trains the network and creates a submission.
#Internal imports
import dataloader
import model
import test
from train import train
import losses
import optimizers
import create_submission
import utils

#External imports
import yaml
import torch
import logging
import torch.optim
import torch.nn as nn
import os

def optimizer(cfg, network):
    """Build the optimizer selected by the configuration.

    Args:
        cfg: Configuration mapping; ``cfg["Optimizer"]`` names the optimizer
            to use (currently only ``"Adam"`` is supported).
        network: Module whose ``parameters()`` are handed to the optimizer.

    Returns:
        The optimizer instance, constructed with its default hyper-parameters.

    Raises:
        KeyError: If ``cfg`` has no ``"Optimizer"`` entry, or if it names an
            optimizer that is not supported.
    """
    # Map names to constructors (not instances) so that only the requested
    # optimizer is ever built, however many entries get added later.
    factories = {"Adam": torch.optim.Adam}
    name = cfg["Optimizer"]
    try:
        factory = factories[name]
    except KeyError:
        # Same exception type as a plain dict lookup, but tell the user what
        # the valid choices are.
        raise KeyError(
            "Unknown optimizer {!r}; supported: {}".format(name, sorted(factories))
        ) from None
    return factory(network.parameters())

if __name__ == "__main__":
    # Entry point: read config.yml, build the data loaders and the network,
    # train for the configured number of epochs while tracking the validation
    # loss through a checkpoint object, then create a submission.

    # logging.basicConfig opens the log file immediately and does not create
    # missing directories, so make sure the folder exists first.
    os.makedirs("logs", exist_ok=True)
    logging.basicConfig(filename='logs/main_unit_test.log', level=logging.INFO)

    # safe_load only builds plain Python objects, and (unlike a bare
    # yaml.load(stream)) does not need an explicit Loader argument.
    # The context manager closes the file handle once parsing is done.
    with open("config.yml") as config_file:
        cfg = yaml.safe_load(config_file)

    use_cuda = torch.cuda.is_available()

    dataset_cfg = cfg["Dataset"]
    trainpath = dataset_cfg["_DEFAULT_TRAIN_FILEPATH"]
    num_days = int(dataset_cfg["num_days"])
    batch_size = int(dataset_cfg["batch_size"])
    num_workers = int(dataset_cfg["num_workers"])
    valid_ratio = float(dataset_cfg["valid_ratio"])
    # NOTE(review): eval() executes whatever expression the config holds for
    # max_num_samples (presumably "None" or an integer literal - confirm).
    # Only run this script with a trusted config file; ast.literal_eval would
    # be a safer parser if the value is always a literal.
    max_num_samples = eval(dataset_cfg["max_num_samples"])

    train_loader, valid_loader = dataloader.get_dataloaders(
        trainpath,
        num_days,
        batch_size,
        num_workers,
        use_cuda,
        valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
        train_transform=dataloader.transform_remove_space_time(),
        valid_transform=dataloader.transform_remove_space_time()
    )

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if use_cuda else 'cpu')

    #network = network.build_network(cfg, 18)

    # Fully connected network: 14 inputs -> 8 -> 35 (x5) -> 1 output, with a
    # ReLU after every linear layer (including the last one).
    network = nn.Sequential(
        nn.Linear(14, 8, bias=False),
        nn.ReLU(),
        nn.Linear(8, 35, bias=True),
        nn.ReLU(),
        nn.Linear(35, 35, bias=True),
        nn.ReLU(),
        nn.Linear(35, 35, bias=True),
        nn.ReLU(),
        nn.Linear(35, 35, bias=True),
        nn.ReLU(),
        nn.Linear(35, 35, bias=True),
        nn.ReLU(),
        nn.Linear(35, 1, bias=True),
        nn.ReLU()
    )

    def init_xavier(module):
        """Apply Xavier-uniform initialisation to the weight of a Linear layer."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)

    # NOTE(review): init_xavier is defined but never applied to the network
    # (network.apply(init_xavier) would enable it). It is left disabled here
    # so the current initialisation/training behaviour is unchanged.
    network = network.to(device)

    # Disabled debugging snippet, kept for reference:
    # for param in list(network.parameters()):
    #     param = 1

    f_loss = losses.RMSLELoss()

    # Use a distinct name so the module-level optimizer() factory is not
    # shadowed by its own result.
    network_optimizer = optimizer(cfg, network)

    logdir = utils.create_unique_logpath(cfg["LogDir"], cfg["Model"]["Name"])
    network_checkpoint = model.ModelCheckpoint(
        os.path.join(logdir, "best_model.pt"), network
    )

    # Anomaly detection is a global switch; enabling it once before the loop
    # is enough (it was previously re-enabled on every epoch).
    torch.autograd.set_detect_anomaly(True)

    for epoch in range(cfg["Training"]["Epochs"]):
        print("Epoch {}".format(epoch))
        train(network, train_loader, f_loss, network_optimizer, device)

        #print(list(network.parameters())[0].grad)
        val_loss = test.test(network, valid_loader, f_loss, device)

        # Hand the validation loss to the checkpoint object after each epoch.
        network_checkpoint.update(val_loss)

        print(" Validation : Loss : {:.4f}".format(val_loss))

    create_submission.create_submission(network, None)

    # Disabled log-directory snippet, kept for reference:
    # logdir = generate_unique_logpath(top_logdir, "linear")
    # print("Logging to {}".format(logdir))
    # # -> Prints out     Logging to   ./logs/linear_1
    # if not os.path.exists(logdir):
    #     os.mkdir(logdir)