# coding: utf-8
"""
Datasets for accessing the CMEMS data

1) PointDataset:
    [i]: return an array of size (15,)+ time(?) + depth(?)+ lat(?) + lon(?)
                and the phyc to predict
    __len__ : total number of data samples with non NAN values

For computing [i] we need a map from a linear index to the [ilat, ilon, idepth, itime] for indexing our 4D volumes. This can be possibly a file saved on disk and not loaded in main memory

The map is required because the measures are possibly sparse in space (latitude, longitude, depth, time) and possibly non contiguous. For example, walking along a constant longitude, you may be on coast, then in water then in coast then
in water hence the sparsity.

2) Dataset
    [i]: return timeseries of size (15,)+ time(?) and additional values for depth(?)+ lat(?) + lon(?) and the phyc to predict. Depth, lat and lon are scalars, not timeseries
    __len__ : total number of data samples with non NAN values

We also need a map for converting linear indices to [ilat, ilon, idepth, itime_begin] and the constructor requires a chunk size in time to split the long running time series into smaller chunks

This file supposes the data have been converted from the netCDF4 to a raw
binary format. From our experiments, this speeds up by 450x the time to iterate
with PointDataset over the whole data

"""

# Standard imports
import os
import sys
import logging
import pathlib
from typing import Union
import struct

from datetime import datetime

# import multiprocessing
from multiprocessing import Lock

# from multiprocessing.synchronize import Lock

# External imports
import numpy as np
import tqdm
import torch
import torch.utils.data as data

_DEFAULT_TRAIN_FILEPATH = "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin"
_DEFAULT_TEST_FILEPATH = (
    "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2017-testing.nc.bin"
)

_ENCODING_LINEAR = "I"
_ENCODING_INDEX = "I"  # h(short) with 2 bytes should be sufficient
_ENCODING_OFFSET_FORMAT = ""
_ENCODING_ENDIAN = "<"


def write_bin_data(fp, fmt, values):
    fp.write(struct.pack(fmt, *values))


def read_bin_data(fp, offset, whence, fmt, lock):
    # Unfortunately does not work
    # If it worked, that would prevent to open/close
    # the file repeatidly
    # with lock:
    with open(fp.name, "rb") as fp:
        # values = struct.unpack_from(fmt, localfp, offset)
        fp.seek(offset, whence)
        nbytes = struct.calcsize(fmt)
        values = struct.unpack(fmt, fp.read(nbytes))
    return values, nbytes


class Dataset(data.Dataset):
    """
    Dataset for training a predictor predicting the phytoplankton
    density from (a series of) environmental variables
    + time + depth + latitude + longitude

    This dataset contains 50 154 984 valid data points
    sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin

    Generating the index takes approximately 20 seconds
    """

    def __init__(
        self,
        filepath: Union[pathlib.Path, str],
        overwrite_index: bool = False,
        train=True,
        subset_file=None,
        num_days=1,
        transform=None,
        target_transform=None,
    ):
        """
        Builds a pointdataset generating the index if necessary or requested

        Arguments:
            filepath: the full path to the nc file to load
            overwrite_index: if True ignores the index and regenerates it
            train: if True, accessing an element also returns the phyc
            subset_file: a filename which holds a list of indices this dataset must use
            num_days: the number of days that each sample considers
            transform: a transform to apply to the input tensor
            target_transform: a transform to apply to the phyc output tensor
        """
        super().__init__()
        if isinstance(filepath, str):
            filepath = pathlib.Path(filepath)
        self.filepath = filepath

        # Load the dataset. It can be indexed by [lat, long, depth, time]
        self.fp = open(filepath, "rb")
        # self.fp.close()
        self.fp_lock = Lock()

        # List all the environmental variables
        # To which we could add time, depth, latitude, longitude
        self.in_variables = [
            "dissic",
            "mlotst",
            "nh4",
            "no3",
            "nppv",
            "o2",
            "ph",
            "po4",
            "so",
            "talk",
            "thetao",
            "uo",
            "vo",
            "zos",
            # "phyc",
        ]
        self.train = train
        if self.train:
            self.out_variable = "phyc"

        # Store the size of the timeseries, in days
        self.num_days = num_days

        # Set up the format of a row
        self.row_format = (
            "?" + "f" * len(self.in_variables) + ("f" if self.train else "")
        )
        self.row_size = struct.calcsize(_ENCODING_ENDIAN + self.row_format)

        # Load the header of the file for the dimension variables
        # Local utilitary function to parse the header
        def _read_dim(fp, offset, base_format, lock):
            fmt = _ENCODING_ENDIAN + "i"
            (dim, nbytes_dim) = read_bin_data(fp, offset, os.SEEK_SET, fmt, lock)

            dim = dim[0]  # the returned values is a tuple
            offset += nbytes_dim

            fmt = _ENCODING_ENDIAN + (base_format * dim)
            (values, nbytes_values) = read_bin_data(fp, offset, os.SEEK_SET, fmt, lock)
            return dim, np.array(values), nbytes_dim + nbytes_values

        self.header_offset = 0
        self.nlatitudes, self.latitudes, nbytes = _read_dim(
            self.fp, self.header_offset, "f", self.fp_lock
        )
        self.header_offset += nbytes

        self.nlongitudes, self.longitudes, nbytes = _read_dim(
            self.fp, self.header_offset, "f", self.fp_lock
        )
        self.header_offset += nbytes

        self.ndepths, self.depths, nbytes = _read_dim(
            self.fp, self.header_offset, "f", self.fp_lock
        )
        self.header_offset += nbytes

        self.ntimes, self.times, nbytes = _read_dim(
            self.fp, self.header_offset, "i", self.fp_lock
        )
        # If we need to convert timestamps to datetimes
        # self.dtimes = np.array([datetime.fromtimestamp(di) for di in dtimes.tolist()])
        self.header_offset += nbytes

        self.lat_chunk_size = self.nlongitudes * self.ndepths * self.ntimes
        self.lon_chunk_size = self.ndepths * self.ntimes
        self.depth_chunk_size = self.ntimes

        logging.info(
            f"The loaded dataset contains {self.nlatitudes} latitudes, {self.nlongitudes} longitudes, {self.ndepths} depths and {self.ntimes} time points"
        )

        # We store an index which maps a linear index of valid measures
        # (i.e. excluding points where there is no phytoplankton record)
        # to the byte offset in the data file
        self._index_file = None
        self._index_lock = Lock()
        self._load_index(filepath.name, overwrite_index)

        # The subset map is a filehandler to a file
        # where the i-th row contains the i-th index in the original dataset
        self._subset_map = None
        if subset_file is not None:
            self._subset_lock = Lock()
            self._load_subset_index(subset_file)

        self.transform = transform
        self.target_transform = target_transform

    def _generate_index(self, indexpath):
        """
        Generate an index for the binary data file
        which maps the linear index to the original data file offset
        and 4D indices of valid measures
        Note we only index valid measures, i.e. excluding coast, etc..

        Generating the index takes around 2 minutes for 50M samples
        """

        fp = open(self.filepath, "rb")
        with open(indexpath, "wb") as fhindex:
            linear_index = 0
            # Rewind the file handler to just after the header
            fp.seek(self.header_offset, os.SEEK_SET)
            for ilatitude in tqdm.tqdm(range(self.nlatitudes)):
                for ilongitude in range(self.nlongitudes):
                    for idepth in range(self.ndepths):
                        # No need to iterate over time
                        # only latitude x longitude x depth
                        # specify a valid or invalid location

                        t0_offset = (
                            self.header_offset
                            + (
                                ilatitude * self.lat_chunk_size
                                + ilongitude * self.lon_chunk_size
                                + idepth * self.depth_chunk_size
                            )
                            * self.row_size
                        )
                        fp.seek(t0_offset, os.SEEK_SET)
                        # Just read the 'valid' field
                        (is_valid,) = struct.unpack(
                            "<?", fp.read(struct.calcsize("<?"))
                        )

                        # If the location is valid
                        # we record all the time samples
                        if is_valid:
                            for dt in range(
                                0, self.ntimes - self.num_days + 1, self.num_days
                            ):
                                fileoffset = t0_offset + dt * self.row_size

                                fmt = (
                                    _ENCODING_ENDIAN
                                    + (_ENCODING_LINEAR * 2)
                                    + (_ENCODING_INDEX * 4)
                                )
                                write_bin_data(
                                    fhindex,
                                    fmt,
                                    (
                                        linear_index,
                                        fileoffset,
                                        ilatitude,
                                        ilongitude,
                                        idepth,
                                        dt,
                                    ),
                                )

                                linear_index += 1
        fp.close()

    def _load_index(self, basename: str, overwrite_index: bool):
        """
        Loads (and possibly compute) an index file
        to convert a linear index idx to its corresponding
        time, depth, latitude, longitude indices

        Arguments:

        """
        indexpath = pathlib.Path(".") / f"{basename}_index.idx"
        # Generate the index if necessary or requested
        if not indexpath.exists() or overwrite_index:
            logging.info("Generating the index")
            self._generate_index(indexpath)
        # And then load the index
        logging.info(f"Loading the index from {indexpath}")
        self._index_file = open(indexpath, "rb")

    def _load_subset_index(self, subsetpath):
        self._subset_map = open(subsetpath, "rb")

    def _get_fileoffset(self, idx):
        fmt = _ENCODING_ENDIAN + (_ENCODING_LINEAR * 2) + (_ENCODING_INDEX * 4)
        whence = 0 if idx >= 0 else 2
        offset = idx * struct.calcsize(fmt)

        (values, _) = read_bin_data(
            self._index_file, offset, whence, fmt, self._index_lock
        )
        linidx = values[0]
        file_offset = values[1]
        tab_indices = values[2:]
        return linidx, file_offset, tab_indices

    def __getitem__(self, idx):
        """
        Access the i-th measure
        Return an array of size (15,)+ time(?) + depth(?)+ lat(?) + lon(?) and the phyc to predict
        """

        # If we are processing a subet, convert the "idx" to the
        # original dataset index
        if self._subset_map is not None:
            fmt = _ENCODING_ENDIAN + _ENCODING_LINEAR
            offset = idx * struct.calcsize(fmt)
            (values, _) = read_bin_data(
                self._subset_map, offset, os.SEEK_SET, fmt, self._subset_lock
            )
            idx = values[0]

        # File offset for the i-th sample
        _, file_offset, (ilatitude, ilongitude, idepth, itime) = self._get_fileoffset(
            idx
        )
        values, _ = read_bin_data(
            self.fp,
            file_offset,
            os.SEEK_SET,
            _ENCODING_ENDIAN + (self.row_format * self.num_days),
            self.fp_lock,
        )

        values = np.array(values).reshape((self.num_days, -1))
        assert np.all(
            values[:, 0]
        )  # This is always expected to be True since only valid
        # locations can be indexed with our generated index map

        # Preprend the latitude, longitude, depth and time
        # TODO Take the time to think about whether this is relevant !
        coordinates = np.ones((self.num_days, 4))
        # print(coordinates.shape, self.times[itime : (itime + self.num_days)])
        coordinates[:, 0] *= self.latitudes[ilatitude]
        coordinates[:, 1] *= self.longitudes[ilongitude]
        coordinates[:, 2] *= self.depths[idepth]
        coordinates[:, 3] = self.times[itime : (itime + self.num_days)]

        # Stick the values with the coordinates
        # This np array is now (num_days , num_coordinates + num_values)
        values = np.hstack([coordinates, values[:, 1:]])

        if self.train:
            in_values = torch.Tensor(values[:, :-1])
            out_values = torch.Tensor(values[:, -1])

            if self.transform is not None:
                in_values = self.transform(in_values)

            if self.target_transform is not None:
                out_values = self.target_transform(out_values)

            return in_values, out_values
        else:
            in_values = torch.Tensor(values[:, :-1])
            if self.transform is not None:
                in_values = self.transform(in_values)
            return in_values

    def __len__(self):
        """
        Returns the total number of valid datapoints for this dataset

        This corresponds to the length of the index or subset index
        """
        # If we are processing a subset, the size of the dataset
        # is the size of the subset
        if self._subset_map is not None:
            subset_size = os.path.getsize(self._subset_map.name)
            return subset_size // struct.calcsize(_ENCODING_ENDIAN + _ENCODING_LINEAR)
        else:
            # Access the last of the index file
            # and get its linear index. This linear index is also
            # the total number of "rows" minus 1 in the file
            lin_idx, _, _ = self._get_fileoffset(-1)
            return lin_idx + 1


def train_valid_split(
    dataset,
    valid_ratio,
    train_subset_filepath,
    valid_subset_filepath,
    max_num_samples=None,
):
    """
    Generates two splits of 1-valid_ratio and valid_ratio fractions
    Takes 2 minutes for generating the splits for 50M datapoints

    For smalls sets, we could use indices in main memory with
    torch.utils.data.Subset
    """

    N = len(dataset)
    if max_num_samples is not None:
        pkeep = max_num_samples / N
        Nvalid = int(valid_ratio * max_num_samples)/1000000
        Ntrain = max_num_samples - Nvalid
        ptrain = Ntrain / max_num_samples

    else:
        pkeep = 1.0
        Nvalid = int(valid_ratio * N)
        Ntrain = N - Nvalid
        ptrain = Ntrain / N

    ftrain = open(train_subset_filepath, "wb")
    fvalid = open(valid_subset_filepath, "wb")

    gen = np.random.default_rng()

    logging.info(f"Generating the subset files from {N} samples")
    for i in tqdm.tqdm(range(N)):
        if gen.uniform() < pkeep:
            if gen.uniform() < ptrain:
                ftrain.write(struct.pack(_ENCODING_ENDIAN + _ENCODING_LINEAR, i))
            else:
                fvalid.write(struct.pack(_ENCODING_ENDIAN + _ENCODING_LINEAR, i))
    fvalid.close()
    ftrain.close()


def get_dataloaders(
    filepath,
    num_days,
    batch_size,
    num_workers,
    pin_memory,
    valid_ratio,
    overwrite_index=True,
    train_transform=None,
    train_target_transform=None,
    valid_transform=None,
    valid_target_transform=None,
    max_num_samples=None,
):
    logging.info("= Dataloaders")
    # Load the base dataset
    logging.info("  - Dataset creation")
    dataset = Dataset(
        filepath, train=True, overwrite_index=overwrite_index, num_days=num_days
    )
    logging.info(f"  - Loaded a dataset with {len(dataset)} samples")

    # Split the number of samples in each fold
    logging.info("  - Splitting the data in training and validation sets")
    train_subset_file = "train_indices.subset"
    valid_subset_file = "valid_indices.subset"
    train_valid_split(
        dataset, valid_ratio, train_subset_file, valid_subset_file, max_num_samples
    )

    logging.info("  - Subset dataset")
    train_dataset = Dataset(
        filepath,
        subset_file=train_subset_file,
        transform=train_transform,
        target_transform=train_target_transform,
        num_days=num_days,
    )
    valid_dataset = Dataset(
        filepath,
        subset_file=valid_subset_file,
        transform=valid_transform,
        target_transform=valid_target_transform,
        num_days=num_days,
    )
    # The sum of the two folds are not expected to be exactly of
    # max_num_samples
    logging.info(f"  - The train fold has {len(train_dataset)} samples")
    logging.info(f"  - The valid fold has {len(valid_dataset)} samples")

    # Build the dataloaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, valid_loader


def get_test_dataloader(
    filepath,
    num_days,
    batch_size,
    num_workers,
    pin_memory,
    overwrite_index=True,
    transform=None,
    target_transform=None,
):
    logging.info("= Dataloaders")
    # Load the base dataset
    logging.info("  - Dataset creation")
    test_dataset = Dataset(
        filepath,
        train=False,
        transform=transform,
        target_transform=target_transform,
        num_days=num_days,
        overwrite_index=overwrite_index,
    )
    logging.info(f"I loaded {len(test_dataset)} values in the test set")

    # Build the dataloader, be carefull not to shuffle the data
    # to ensure we keep the
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return test_loader


def test_point_dataset():
    dataset = Dataset(_DEFAULT_TRAIN_FILEPATH, train=True, overwrite_index=True)
    logging.info(f"Dataset loaded with {len(dataset)} samples")

    xi, yi = dataset[103]
    print(f"Loaded one sample with shapes of xi({xi.shape}), yi({yi.shape})")

    # Test iterating over the whole dataset
    # This takes around 1 hour to iterate over the 50M samples
    # logging.info("Iterating over the whole dataset, sample by sample")
    # for i in tqdm.tqdm(range(len(dataset))):
    #     xi, yi = dataset[i]


def test_getitem():
    dataset = Dataset(_DEFAULT_TRAIN_FILEPATH, train=True, overwrite_index=True)
    logging.info(f"Dataset loaded with {len(dataset)} samples")
    idx = 113728

    lin_index, file_offset, tab_indices = dataset._get_fileoffset(idx)
    logging.info(
        f"The idx {idx} corresponds to : \n\tlinear index={lin_index}\n\tfile offset={file_offset}\n\ttab indices={tab_indices}"
    )
    train_loader, valid_loader = get_dataloaders(
        trainpath,
        num_days,
        batch_size,
        num_workers,
        use_cuda,
        valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
    )

def test_dataloader():
    logging.info("====> Test dataloader")
    use_cuda = torch.cuda.is_available()
    trainpath = _DEFAULT_TRAIN_FILEPATH
    num_days = 1  # Test with sequence of 1 day
    batch_size = 128
    num_workers = 7
    valid_ratio = 0.2
    # max_num_samples = 1000
    max_num_samples = None

    train_loader, valid_loader = get_dataloaders(
        trainpath,
        num_days,
        batch_size,
        num_workers,
        use_cuda,
        valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
    )

    it = iter(train_loader)
    X, Y = next(it)
    logging.info(f"Got a minibatch of size {X.shape} -> {Y.shape}")

    # This takes around 20 minutes for 50M samples
    minis = []
    maxis = []
    logging.info("Iterating over the whole data loader")
    for X, Y in tqdm.tqdm(train_loader):
        minis.append(Y.min())
        maxis.append(Y.max())
    logging.info(f"phyc ranges in [{min(minis)},{max(maxis)}]")


def test_time_dataset():
    import matplotlib.pyplot as plt

    use_cuda = torch.cuda.is_available()
    trainpath = _DEFAULT_TRAIN_FILEPATH
    num_days = 35

    batch_size = 24
    num_workers = 7
    valid_ratio = 0.2
    max_num_samples = None

    train_dataset = Dataset(
        trainpath,
        overwrite_index=True,
        train=True,
        num_days=num_days,
    )
    logging.info(f"The training set contains {len(train_dataset)} samples")

    # Test indexing one element
    X, y = train_dataset[166]
    logging.info(f"Got one sample of shape {X.shape}, {y.shape}")

    # Iterate over the whole training set, around 2 minutes
    # logging.info("Iterating over the whole dataset, sample by sample")
    # for i in tqdm.tqdm(range(len(train_dataset))):
    #     X, y = train_dataset[i]

    # Create the dataloader for getting minibatches
    train_loader, valid_loader = get_dataloaders(
        trainpath,
        num_days,
        batch_size,
        num_workers,
        use_cuda,
        valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
    )

    it = iter(train_loader)
    X, Y = next(it)
    logging.info(f"Got one minibatch of shape {X.shape}, {Y.shape}")

    plt.figure()
    plt.plot(Y.numpy().T)
    plt.show()

    logging.info(f"Got a minibatch of size {X.shape} -> {Y.shape}")
    logging.info(
        "The tensors are in order (B, T, N) for the input and (B, T) for the output. Be carefull when using convolutional layers where 1D convolutions expects (B, N, T). Be carefull when using recurrent layers which are by default Time first"
    )
    logging.info("Iterating over the whole train loader for testing the time it takes")
    # Around 1 minute
    for X, Y in tqdm.tqdm(train_loader):
        continue


def test_time_test_dataset():
    import matplotlib.pyplot as plt

    use_cuda = torch.cuda.is_available()
    testpath = _DEFAULT_TEST_FILEPATH
    num_days = 365

    batch_size = 24
    num_workers = 7

    test_dataset = Dataset(
        testpath,
        overwrite_index=True,
        train=False,
        num_days=num_days,
    )
    logging.info(f"The test set contains {len(test_dataset)} samples")

    # Test indexing one element
    X = test_dataset[166]
    logging.info(f"Got one sample of shape {X.shape}")

    # Iterate over the whole training set, around 2 minutes
    # logging.info("Iterating over the whole dataset, sample by sample")
    # for i in tqdm.tqdm(range(len(test_dataset))):
    #     X = test_dataset[i]

    # Create the dataloader for getting minibatches
    test_loader = get_test_dataloader(
        testpath,
        num_days,
        batch_size,
        num_workers,
        use_cuda,
        overwrite_index=True,
    )

    it = iter(test_loader)
    X = next(it)
    logging.info(f"Got one minibatch of shape {X.shape}")

    plt.figure()
    plt.plot(X[:, :, 6].T)
    plt.title("NH4")
    plt.show()

    logging.info(f"Got a minibatch of size {X.shape}")
    logging.info(
        "The tensors are in order (B, T, N). Be carefull when using convolutional layers where 1D convolutions expects (B, N, T). Be carefull when using recurrent layers which are by default Time first"
    )
    logging.info("Iterating over the whole train loader for testing the time it takes")
    # Around 1 minute
    for X in tqdm.tqdm(test_loader):
        continue


def test_transforms():
    import matplotlib.pyplot as plt

    logging.info("====> Test transforms dataloader")
    use_cuda = torch.cuda.is_available()
    trainpath = _DEFAULT_TRAIN_FILEPATH
    num_days = 834  # Test with sequence of 1 day
    batch_size = 128
    num_workers = 7
    valid_ratio = 0.2
    max_num_samples = 1000

    def train_transform(X):
        """
        Transform to be applied to an input sample X of shape (T, N)
        """
        # The variables in X are
        #     latitude, longitude, depth, time,
        # followed by
        # dissic, mlotst, nh4, no3, nppv, o2, ph, po4, so, talk, thetao, uo
        # vo, zos

        # As a matter of example, we drop some input variables
        # For example, we drop the latitude and longitude
        return X[:, 2:]

    def valid_transform(X):
        return X

    def target_transform(Y):
        return Y.log()

    train_loader, _ = get_dataloaders(
        trainpath,
        num_days,
        batch_size,
        num_workers,
        use_cuda,
        valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
        train_target_transform=target_transform,
        train_transform=train_transform,
        valid_target_transform=target_transform,
        valid_transform=valid_transform,
    )

    it = iter(train_loader)
    X, Y = next(it)
    logging.info(f"Got one minibatch of shape {X.shape}, {Y.shape}")
    logging.info(
        "The tensors are in order (B, T, N) for the input and (B, T) for the output. Be carefull when using convolutional layers where 1D convolutions expects (B, N, T). Be carefull when using recurrent layers which are by default Time first"
    )

    plt.figure()

    plt.plot(list(map(datetime.fromtimestamp, X[0, :, 1].tolist())), Y[0, :])
    plt.xticks(rotation=70)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(message)s")

    logging.info("\n" + "=" * 80)
    test_point_dataset()

    logging.info("\n" + "=" * 80)
    test_getitem()

    logging.info("\n" + "=" * 80)
    test_dataloader()

    logging.info("\n" + "=" * 80)
    test_time_dataset()

    logging.info("\n" + "=" * 80)
    test_time_test_dataset()

    logging.info("\n" + "=" * 80)
    test_transforms()