#Internal imports
import dataloader
import model
import test
from train import train
import losses
import optimizers
import create_submission
import utils

#External imports
import yaml
import torch
import wandb
import logging
import torch.optim
import torch.nn as nn
import os
import argparse

def optimizer(cfg, network):
    """Build the optimizer selected by ``cfg["Optimizer"]`` for ``network``.

    Args:
        cfg: Configuration mapping; its ``"Optimizer"`` entry holds the name
            of the optimizer to build.
        network: Model whose ``parameters()`` are handed to the optimizer.

    Returns:
        The optimizer instance, constructed with its default hyper-parameters.

    Raises:
        KeyError: If ``"Optimizer"`` is missing from ``cfg`` or names an
            optimizer that is not supported.
    """
    # Store factories rather than instances so that only the requested
    # optimizer is actually constructed.
    factories = {"Adam": torch.optim.Adam}
    name = cfg["Optimizer"]
    try:
        factory = factories[name]
    except KeyError:
        raise KeyError(
            "Unknown optimizer {!r}; supported optimizers: {}".format(
                name, sorted(factories)
            )
        ) from None
    return factory(network.parameters())

if __name__ == "__main__":
    # Script entry point: parse the command line, read config.yml, build the
    # data loaders / model / loss / optimizer, run the train-validate loop and
    # finally call create_submission with the trained network.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--no_wandb",
        action="store_true",
        help="If specified, no log will be sent to wandb. Especially useful when running batch jobs.",
    )

    parser.add_argument(
        "--rootDir",
        default=None,
        help="Directory in which the log files will be stored",
    )

    args = parser.parse_args()

    # safe_load: yaml.load() without an explicit Loader is rejected by recent
    # PyYAML releases; the context manager also makes sure the file is closed.
    with open("config.yml") as config_file:
        cfg = yaml.safe_load(config_file)

    # The command-line value takes precedence over the configured LogDir.
    rootDir = args.rootDir if args.rootDir is not None else cfg["LogDir"]

    # logging.basicConfig fails if the target directory does not exist.
    if rootDir:
        os.makedirs(rootDir, exist_ok=True)
    # os.path.join keeps the log file inside rootDir whether or not rootDir
    # ends with a path separator.
    logging.basicConfig(
        filename=os.path.join(rootDir, 'main_unit_test.log'),
        level=logging.INFO,
    )

    use_cuda = torch.cuda.is_available()

    dataset_cfg = cfg["Dataset"]
    trainpath = dataset_cfg["_DEFAULT_TRAIN_FILEPATH"]
    num_days = int(dataset_cfg["num_days"])
    batch_size = int(dataset_cfg["batch_size"])
    num_workers = int(dataset_cfg["num_workers"])
    valid_ratio = float(dataset_cfg["valid_ratio"])
    # SECURITY NOTE(review): eval() executes arbitrary code taken from
    # config.yml. Only run this script with a trusted configuration file.
    max_num_samples = eval(dataset_cfg["max_num_samples"])

    epochs = int(cfg["Training"]["Epochs"])
    log_freq = int(cfg["Wandb"]["log_freq"])
    log_interval = int(cfg["Wandb"]["log_interval"])

    if not args.no_wandb:
        wandb.init(
            entity="wherephytoplankton",
            project="Kaggle phytoplancton",
            config={"batch_size": batch_size, "epochs": epochs},
        )

    # Re-compute the statistics or use the stored ones
    approx_stats = cfg["ApproximativeStats"]

    if approx_stats:
        # SECURITY NOTE(review): same eval() caveat as above.
        MEAN = eval(cfg["ApproximativeMean"])
        STD = eval(cfg["ApproximativeSTD"])
        MAX = eval(cfg["ApproximativeMaxi"])
        MIN = eval(cfg["ApproximativeMini"])
    else:
        MEAN, STD, MAX, MIN = dataloader.get_stats_train_dataset(
            trainpath,
            num_days,
            batch_size,
            num_workers,
            use_cuda,
            valid_ratio,
            overwrite_index=True,
            max_num_samples=max_num_samples,
            train_transform=None,
            valid_transform=None,
        )

    def _make_transform():
        # Build a fresh transform on every call so that the training,
        # validation and submission pipelines never share an instance.
        return dataloader.composite_transform(
            dataloader.transform_remove_space_time(),
            dataloader.transform_min_max_scaling(MIN, MAX),
        )

    train_loader, valid_loader = dataloader.get_dataloaders(
        trainpath,
        num_days,
        batch_size,
        num_workers,
        use_cuda,
        valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
        train_transform=_make_transform(),
        valid_transform=_make_transform(),
    )

    # Fall back to the CPU when CUDA is not available.
    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    network = model.build_model(cfg, 14)
    network = network.to(device)
    model.initialize_model(cfg, network)

    f_loss = losses.RMSLELoss()

    # Use a distinct name so the module-level optimizer() factory is not
    # overwritten by its own result.
    network_optimizer = optimizer(cfg, network)

    logdir, raw_run_name = utils.create_unique_logpath(rootDir, cfg["Model"]["Name"])
    network_checkpoint = model.ModelCheckpoint(logdir + "/best_model.pt", network)

    if not args.no_wandb:
        wandb.run.name = raw_run_name
        wandb.watch(network, log_freq=log_freq)

    # Iterate over the int-converted epoch count read above.
    for t in range(epochs):
        print("Epoch {}".format(t))
        train(args, network, train_loader, f_loss, network_optimizer, device, log_interval)

        val_loss = test.test(network, valid_loader, f_loss, device)

        # Hand the validation loss to the checkpoint helper (presumably it
        # keeps the best model seen so far -- confirm in model.ModelCheckpoint).
        network_checkpoint.update(val_loss)

        print(" Validation : Loss : {:.4f}".format(val_loss))
        if not args.no_wandb:
            wandb.log({"val_loss": val_loss})

    create_submission.create_submission(network, _make_transform(), device, rootDir)