Skip to content
Snippets Groups Projects
main.py 4.08 KiB
Newer Older
Yandi's avatar
Yandi committed
#Internal imports
import dataloader
import model
import test
from train import train
import losses
Yandi's avatar
Yandi committed
import optimizers
Yandi's avatar
Yandi committed
import create_submission
import utils

#External imports
import yaml
Yandi's avatar
Yandi committed
import torch
import wandb
Yandi's avatar
Yandi committed
import logging
Yandi's avatar
Yandi committed
import torch.optim
Yandi's avatar
Yandi committed
import torch.nn as nn
Yandi's avatar
Yandi committed
import os
import argparse
Yandi's avatar
Yandi committed

Yandi's avatar
Yandi committed
def optimizer(cfg, network):
    """Build the optimizer selected by ``cfg["Optimizer"]`` for ``network``.

    Args:
        cfg: Configuration mapping. Its ``"Optimizer"`` entry names the
            optimizer to build; only ``"Adam"`` is registered here.
        network: Object exposing ``parameters()``; the returned iterable is
            handed to the optimizer constructor.

    Returns:
        The optimizer instance, constructed over ``network.parameters()``
        with the optimizer's default hyper-parameters.

    Raises:
        KeyError: If ``cfg`` has no ``"Optimizer"`` entry, or if the
            requested name is not a registered optimizer.
    """
    # Map names to constructors (not instances) so that only the requested
    # optimizer is actually built.
    factories = {"Adam": torch.optim.Adam}
    name = cfg["Optimizer"]
    try:
        factory = factories[name]
    except KeyError:
        raise KeyError(
            "Unknown optimizer {!r}; available: {}".format(name, sorted(factories))
        ) from None
    return factory(network.parameters())

if __name__ == "__main__":
    # Script entry point: parse the command line, read config.yml, build the
    # data loaders and the model, run the train/validation loop and finally
    # write the submission.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--no_wandb",
        action="store_true",
        help="If specified, no log will be sent to wandb. Especially useful when running batch jobs.",
    )

    parser.add_argument(
        "--rootDir",
        default=None,
        help="Directory in which the log files will be stored",
    )

    args = parser.parse_args()

    # yaml.load() without an explicit Loader is rejected by PyYAML >= 6;
    # safe_load parses plain YAML data. The context manager closes the file.
    with open("config.yml") as config_file:
        cfg = yaml.safe_load(config_file)

    # The command-line value takes precedence over the configured directory.
    rootDir = args.rootDir if args.rootDir is not None else cfg["LogDir"]

    # os.path.join keeps the log file inside rootDir whether or not the
    # directory was given with a trailing separator.
    logging.basicConfig(
        filename=os.path.join(rootDir, "main_unit_test.log"),
        level=logging.INFO,
    )

    use_cuda = torch.cuda.is_available()

    dataset_cfg = cfg["Dataset"]
    trainpath = dataset_cfg["_DEFAULT_TRAIN_FILEPATH"]
    num_days = int(dataset_cfg["num_days"])
    batch_size = int(dataset_cfg["batch_size"])
    num_workers = int(dataset_cfg["num_workers"])
    valid_ratio = float(dataset_cfg["valid_ratio"])
    # NOTE(review): eval() executes whatever expression config.yml contains.
    # Acceptable only while the config file is trusted; consider
    # ast.literal_eval if the stored values are plain literals.
    max_num_samples = eval(dataset_cfg["max_num_samples"])
    epochs = int(cfg["Training"]["Epochs"])
    log_freq = int(cfg["Wandb"]["log_freq"])
    log_interval = int(cfg["Wandb"]["log_interval"])

    if not args.no_wandb:
        wandb.init(
            entity="wherephytoplankton",
            project="Kaggle phytoplancton",
            config={"batch_size": batch_size, "epochs": epochs},
        )

    # Re-compute the statistics or use the stored ones
    approx_stats = cfg["ApproximativeStats"]

    if approx_stats:
        # Stored statistics; same eval() caveat as max_num_samples above.
        MEAN = eval(cfg["ApproximativeMean"])
        STD = eval(cfg["ApproximativeSTD"])
        MAX = eval(cfg["ApproximativeMaxi"])
        MIN = eval(cfg["ApproximativeMini"])
    else:
        # Only recompute when stored statistics are not requested; without
        # this branch the stored values were always overwritten and MIN/MAX
        # were undefined when ApproximativeStats was false.
        MEAN, STD, MAX, MIN = dataloader.get_stats_train_dataset(
            trainpath,
            num_days,
            batch_size,
            num_workers,
            use_cuda,
            valid_ratio,
            overwrite_index=True,
            max_num_samples=max_num_samples,
            train_transform=None,
            valid_transform=None,
        )

    def _make_transform():
        # Build a fresh transform per consumer (train, validation,
        # submission), as the original code did with three separate calls.
        return dataloader.composite_transform(
            dataloader.transform_remove_space_time(),
            dataloader.transform_min_max_scaling(MIN, MAX),
        )

    train_loader, valid_loader = dataloader.get_dataloaders(
        trainpath,
        num_days,
        batch_size,
        num_workers,
        use_cuda,
        valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
        train_transform=_make_transform(),
        valid_transform=_make_transform(),
    )

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    # NOTE(review): 14 is passed straight to build_model; presumably the
    # number of input features - confirm against model.build_model.
    network = model.build_model(cfg, 14)
    network = network.to(device)

    model.initialize_model(cfg, network)

    f_loss = losses.RMSLELoss()
    # Bound to a distinct name so the optimizer() factory is not shadowed.
    network_optimizer = optimizer(cfg, network)

    logdir, raw_run_name = utils.create_unique_logpath(rootDir, cfg["Model"]["Name"])
    network_checkpoint = model.ModelCheckpoint(
        os.path.join(logdir, "best_model.pt"), network
    )

    if not args.no_wandb:
        wandb.run.name = raw_run_name
        wandb.watch(network, log_freq=log_freq)

    # Use the already-converted epoch count rather than the raw config value.
    for t in range(epochs):
        print("Epoch {}".format(t))
        train(args, network, train_loader, f_loss, network_optimizer, device, log_interval)

        val_loss = test.test(network, valid_loader, f_loss, device)

        # The checkpoint object receives the validation loss each epoch.
        network_checkpoint.update(val_loss)

        print(" Validation : Loss : {:.4f}".format(val_loss))
        if not args.no_wandb:
            wandb.log({"val_loss": val_loss})

    create_submission.create_submission(network, _make_transform(), device, rootDir)