# Standard library imports
import argparse
import logging
import os

# External imports
import torch
import torch.nn as nn
import torch.optim
import wandb
import yaml

# Internal imports
import dataloader
import model
import my_test
import my_train
import utils

def choose_optimizer(cfg, network):
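    """Build and return the optimizer named by cfg["Optimizer"] over the network's parameters."""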
    # Map config names to optimizer factories; the lambda defers construction so
    # only the requested optimizer is instantiated.
    optimizers = {"Adam": lambda: torch.optim.Adam(network.parameters())}
    return optimizers[cfg["Optimizer"]]()
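
# How the mapping above could be extended if the config ever names another
# optimizer; the "SGD" entry and its learning rate are illustrative only, not
# part of the current config:
#
#   optimizers = {
#       "Adam": lambda: torch.optim.Adam(network.parameters()),
#       "SGD": lambda: torch.optim.SGD(network.parameters(), lr=1e-3),
#   }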

def train(args, cfg):
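    """Train the model described by cfg and return the log directory where the
    best checkpoint was written (None when --no_log is set)."""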
    rootDir = args.rootDir if args.rootDir is not None else cfg["LogDir"]

    logging.basicConfig(filename=os.path.join(rootDir, 'main_unit_test.log'), level=logging.INFO)
    

    use_cuda = torch.cuda.is_available()
    trainpath = args.PATHTOTRAININGSET if args.PATHTOTRAININGSET is not None else cfg["Dataset"]["_DEFAULT_TRAIN_FILEPATH"]
    num_days = int(cfg["Dataset"]["num_days"])
    batch_size = int(cfg["Dataset"]["batch_size"])
    num_workers = int(cfg["Dataset"]["num_workers"])
    valid_ratio = float(cfg["Dataset"]["valid_ratio"])
    # max_num_samples is stored as a Python expression (e.g. "None" or an int)
    max_num_samples = eval(cfg["Dataset"]["max_num_samples"])

    epochs          = int(cfg["Training"]["Epochs"])
    log_freq        = int(cfg["Wandb"]["log_freq"])
    log_interval    = int(cfg["Wandb"]["log_interval"])

    dataset_transform = cfg["Dataset"]["Transform"]
    input_size = 14 if "space_time" in dataset_transform else (17 if "time" in dataset_transform else 18)

    if not args.no_log:
        wandb.init(
            entity="wherephytoplankton",
            project="Kaggle phytoplancton",
            config={"batch_size": batch_size, "epochs": epochs},
        )

    # Re-compute the statistics or use the stored ones. MEAN, STD, MAX and MIN
    # are expected to be referenced by the Transform expression evaluated below.
    approx_stats = cfg["ApproximativeStats"]

    if approx_stats:
        # The stored statistics are Python expressions (e.g. tensor literals)
        MEAN = eval(cfg["ApproximativeMean"])
        STD = eval(cfg["ApproximativeSTD"])
        MAX = eval(cfg["ApproximativeMaxi"])
        MIN = eval(cfg["ApproximativeMini"])
    else:
        MEAN, STD, MAX, MIN = dataloader.get_stats_train_dataset(
            cfg,
            trainpath,
            num_days,
            batch_size,
            num_workers,
            use_cuda,
            valid_ratio,
            overwrite_index=True,
            max_num_samples=max_num_samples,
            train_transform=None,
            valid_transform=None,
        )

    train_loader, valid_loader = dataloader.get_dataloaders(
        cfg,
        trainpath,
        num_days,
        batch_size,
        num_workers,
        use_cuda,
        valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
        train_transform=eval(dataset_transform),
        valid_transform=eval(dataset_transform),
    )

    device = torch.device('cuda' if use_cuda else 'cpu')

    network = model.build_model(cfg, input_size)

    network = network.to(device)

    model.initialize_model(cfg, network)

    f_loss = model.RMSLELoss()

    optimizer = choose_optimizer(cfg, network)

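    # ReduceLROnPlateau halves the learning rate (factor=0.5) once the validation
    # loss has stopped improving for 5 consecutive epochs (patience=5).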
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'min',
        patience=5,
        threshold=0.2,
        factor=0.5,
    )

    experiment_name = args.experimentName
    network_checkpoint = None
    if not args.no_log:
        logdir, raw_run_name = utils.create_unique_logpath(rootDir, cfg["Model"]["Name"] + experiment_name)
        network_checkpoint = model.ModelCheckpoint(logdir + "/best_model.pt", network)
        wandb.run.name = raw_run_name
        wandb.watch(network, log_freq=log_freq)

    if args.detect_anomaly:
        torch.autograd.set_detect_anomaly(True)

    for t in range(epochs):
        print(f"Epoch {t+1}")
        my_train.train(args, network, train_loader, f_loss, optimizer, device, log_interval)

        val_loss = my_test.test(network, valid_loader, f_loss, device)

        # Only track the best model when checkpointing is enabled (not --no_log)
        if network_checkpoint is not None:
            network_checkpoint.update(val_loss)

        scheduler.step(val_loss)


        print("Validation : Loss : {:.4f}".format(val_loss))
        if not args.no_log:
            wandb.log({"val_loss": val_loss})

    if not args.no_log:
        utils.write_summary(logdir, network, optimizer, val_loss)
        logging.info(f"Best model saved in folder {logdir}")
        return logdir

    return None


def test(args, cfg):
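    """Load the checkpoint stored under args.PATHTOCHECKPOINT and create a submission file."""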

    dataset_transform = cfg["Dataset"]["Transform"]

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    model_path = os.path.join(args.PATHTOCHECKPOINT, "best_model.pt")

    # Re-compute the statistics or use the stored ones. MEAN, STD, MAX and MIN
    # are expected to be referenced by the Transform expression evaluated below.
    approx_stats = cfg["ApproximativeStats"]

    if approx_stats:
        MEAN = eval(cfg["ApproximativeMean"])
        STD = eval(cfg["ApproximativeSTD"])
        MAX = eval(cfg["ApproximativeMaxi"])
        MIN = eval(cfg["ApproximativeMini"])
    else:
        # Dataset parameters needed to recompute the statistics
        trainpath = cfg["Dataset"]["_DEFAULT_TRAIN_FILEPATH"]
        num_days = int(cfg["Dataset"]["num_days"])
        batch_size = int(cfg["Dataset"]["batch_size"])
        num_workers = int(cfg["Dataset"]["num_workers"])
        valid_ratio = float(cfg["Dataset"]["valid_ratio"])
        max_num_samples = eval(cfg["Dataset"]["max_num_samples"])
        MEAN, STD, MAX, MIN = dataloader.get_stats_train_dataset(
            cfg,
            trainpath,
            num_days,
            batch_size,
            num_workers,
            use_cuda,
            valid_ratio,
            overwrite_index=True,
            max_num_samples=max_num_samples,
            train_transform=None,
            valid_transform=None,
        )

    input_size = 14 if "space_time" in dataset_transform else (17 if "time" in dataset_transform else 18)

    network = model.build_model(cfg, input_size)

    network = network.to(device)

    network.load_state_dict(torch.load(model_path, map_location=device))

    utils.create_submission(args, cfg, network, eval(dataset_transform), device, args.PATHTOCHECKPOINT)

    logging.info(f"The submission CSV file has been created in the folder {args.PATHTOCHECKPOINT}")

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--detect_anomaly",
        action="store_true",
        help="If specified, torch.autograd.set_detect_anomaly(True) will be activated",
    )

    parser.add_argument(
        "--no_log",
        action="store_true",
        help="If specified, no folder will be created while training the model and no log will be sent to wandb.",
    )

    parser.add_argument(
        "--rootDir",
        default=None,
        help="Directory in which the log files will be stored",
    )

    parser.add_argument(
        "--experimentName",
        default="",
        help="Name of the experiment; affects the name of the wandb run and of the created folder",
    )

    parser.add_argument(
        "--PATHTOTESTSET",
        default=None,
        help="Path of the file on which the model will be tested",
    )

    parser.add_argument(
        "--PATHTOTRAININGSET",
        default=None,
        help="Path of the file on which the model will be trained",
    )

    parser.add_argument(
        "--PATHTOCHECKPOINT",
        default="./logs/BestBidirectionalLSTM/",
        help="Path of the directory containing the model to load (with the trailing /)",
    )

    parser.add_argument(
        "command",
        choices=["train", "test"],
        help="Whether to train a model or to create a submission from a stored checkpoint",
    )

    args = parser.parse_args()
    
    with open("config.yml") as config_file:
        cfg = yaml.load(config_file, Loader=yaml.FullLoader)
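
    # A minimal config.yml sketch consistent with the keys read in this file;
    # every value below is a placeholder, not one used in the actual experiments:
    #
    #   LogDir: ./logs/
    #   Optimizer: Adam
    #   ApproximativeStats: True
    #   ApproximativeMean: "torch.tensor([...])"   # Python expressions, eval()'d above
    #   ApproximativeSTD: "torch.tensor([...])"
    #   ApproximativeMaxi: "torch.tensor([...])"
    #   ApproximativeMini: "torch.tensor([...])"
    #   Model:
    #     Name: BidirectionalLSTM
    #   Training:
    #     Epochs: 20
    #   Wandb:
    #     log_freq: 100
    #     log_interval: 100
    #   Dataset:
    #     _DEFAULT_TRAIN_FILEPATH: ./data/train.bin
    #     num_days: 365
    #     batch_size: 128
    #     num_workers: 4
    #     valid_ratio: 0.2
    #     max_num_samples: "None"
    #     Transform: "..."  # expression eval()'d for the dataloaders; input_size
    #                       # checks for "space_time" / "time" in this string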

    if args.command == "train":
        train(args, cfg)
    else:
        test(args, cfg)
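
    # Example invocations (assuming this file is the project entry point, e.g. main.py):
    #   python main.py train --experimentName baseline
    #   python main.py test --PATHTOCHECKPOINT ./logs/BestBidirectionalLSTM/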