diff --git a/bindataset.py b/bindataset.py
deleted file mode 100644
index 8b4126f24a964bc6137e51b02432d0c16c14a7ac..0000000000000000000000000000000000000000
--- a/bindataset.py
+++ /dev/null
@@ -1,811 +0,0 @@
-# coding: utf-8
-"""
-Datasets for accessing the CMEMS data
-
-1) PointDataset:
-    [i]: return an array of size (15,)+ time(?) + depth(?)+ lat(?) + lon(?)
-                and the phyc to predict
-    __len__ : total number of data samples with non NAN values
-
-For computing [i] we need a map from a linear index to the [ilat, ilon, idepth, itime] for indexing our 4D volumes. This can be possibly a file saved on disk and not loaded in main memory
-
-The map is required because the measures are possibly sparse in space (latitude, longitude, depth, time) and possibly non contiguous. For example, walking along a line of constant
-longitude, you may be on the coast, then in water, then on the coast again, then in water, hence the sparsity.
-
-2) Dataset
-    [i]: return timeseries of size (15,)+ time(?) and additional values for depth(?)+ lat(?) + lon(?) and the phyc to predict. Depth, lat and lon are scalars, not timeseries
-    __len__ : total number of data samples with non NAN values
-
-We also need a map for converting linear indices to [ilat, ilon, idepth, itime_begin] and the constructor requires a chunk size in time to split the long running time series into smaller chunks
-
-This file assumes the data have been converted from netCDF4 to a raw
-binary format. From our experiments, this makes iterating over the whole
-data with PointDataset about 450x faster.
-
-"""
-
-# Standard imports
-import os
-import sys
-import logging
-import pathlib
-from typing import Union
-import struct
-
-from datetime import datetime
-
-# import multiprocessing
-from multiprocessing import Lock
-
-# from multiprocessing.synchronize import Lock
-
-# External imports
-import numpy as np
-import tqdm
-import torch
-import torch.utils.data as data
-
-_DEFAULT_TRAIN_FILEPATH = "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin"
-_DEFAULT_TEST_FILEPATH = (
-    "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2017-testing.nc.bin"
-)
-
-_ENCODING_LINEAR = "I"
-_ENCODING_INDEX = "I"  # h(short) with 2 bytes should be sufficient
-_ENCODING_OFFSET_FORMAT = ""
-_ENCODING_ENDIAN = "<"
-
-
-def write_bin_data(fp, fmt, values):
-    fp.write(struct.pack(fmt, *values))
-
-
-def read_bin_data(fp, offset, whence, fmt, lock):
-    # Unfortunately this does not work
-    # If it worked, it would avoid opening/closing
-    # the file repeatedly
-    # with lock:
-    with open(fp.name, "rb") as fp:
-        # values = struct.unpack_from(fmt, localfp, offset)
-        fp.seek(offset, whence)
-        nbytes = struct.calcsize(fmt)
-        values = struct.unpack(fmt, fp.read(nbytes))
-    return values, nbytes
-
-
-class Dataset(data.Dataset):
-    """
-    Dataset for training a predictor predicting the phytoplankton
-    density from (a series of) environmental variables
-    + time + depth + latitude + longitude
-
-    This dataset contains 50 154 984 valid data points
-    for sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin
-
-    Generating the index takes approximately 20 seconds
-    """
-
-    def __init__(
-        self,
-        filepath: Union[pathlib.Path, str],
-        overwrite_index: bool = False,
-        train=True,
-        subset_file=None,
-        num_days=1,
-        transform=None,
-        target_transform=None,
-    ):
-        """
-        Builds a pointdataset generating the index if necessary or requested
-
-        Arguments:
-            filepath: the full path to the nc file to load
-            overwrite_index: if True ignores the index and regenerates it
-            train: if True, accessing an element also returns the phyc
-            subset_file: a filename which holds a list of indices this dataset must use
-            num_days: the number of days that each sample considers
-            transform: a transform to apply to the input tensor
-            target_transform: a transform to apply to the phyc output tensor
-        """
-        super().__init__()
-        if isinstance(filepath, str):
-            filepath = pathlib.Path(filepath)
-        self.filepath = filepath
-
-        # Load the dataset. It can be indexed by [lat, long, depth, time]
-        self.fp = open(filepath, "rb")
-        # self.fp.close()
-        self.fp_lock = Lock()
-
-        # List all the environmental variables
-        # To which we could add time, depth, latitude, longitude
-        self.in_variables = [
-            "dissic",
-            "mlotst",
-            "nh4",
-            "no3",
-            "nppv",
-            "o2",
-            "ph",
-            "po4",
-            "so",
-            "talk",
-            "thetao",
-            "uo",
-            "vo",
-            "zos",
-            # "phyc",
-        ]
-        self.train = train
-        if self.train:
-            self.out_variable = "phyc"
-
-        # Store the size of the timeseries, in days
-        self.num_days = num_days
-
-        # Set up the format of a row
-        self.row_format = (
-            "?" + "f" * len(self.in_variables) + ("f" if self.train else "")
-        )
-        self.row_size = struct.calcsize(_ENCODING_ENDIAN + self.row_format)
-
-        # Load the header of the file for the dimension variables
-        # Local utility function to parse the header
-        def _read_dim(fp, offset, base_format, lock):
-            fmt = _ENCODING_ENDIAN + "i"
-            (dim, nbytes_dim) = read_bin_data(fp, offset, os.SEEK_SET, fmt, lock)
-
-            dim = dim[0]  # the returned value is a tuple
-            offset += nbytes_dim
-
-            fmt = _ENCODING_ENDIAN + (base_format * dim)
-            (values, nbytes_values) = read_bin_data(fp, offset, os.SEEK_SET, fmt, lock)
-            return dim, np.array(values), nbytes_dim + nbytes_values
-
-        self.header_offset = 0
-        self.nlatitudes, self.latitudes, nbytes = _read_dim(
-            self.fp, self.header_offset, "f", self.fp_lock
-        )
-        self.header_offset += nbytes
-
-        self.nlongitudes, self.longitudes, nbytes = _read_dim(
-            self.fp, self.header_offset, "f", self.fp_lock
-        )
-        self.header_offset += nbytes
-
-        self.ndepths, self.depths, nbytes = _read_dim(
-            self.fp, self.header_offset, "f", self.fp_lock
-        )
-        self.header_offset += nbytes
-
-        self.ntimes, self.times, nbytes = _read_dim(
-            self.fp, self.header_offset, "i", self.fp_lock
-        )
-        # If we need to convert timestamps to datetimes
-        # self.dtimes = np.array([datetime.fromtimestamp(di) for di in dtimes.tolist()])
-        self.header_offset += nbytes
-
-        self.lat_chunk_size = self.nlongitudes * self.ndepths * self.ntimes
-        self.lon_chunk_size = self.ndepths * self.ntimes
-        self.depth_chunk_size = self.ntimes
-
-        logging.info(
-            f"The loaded dataset contains {self.nlatitudes} latitudes, {self.nlongitudes} longitudes, {self.ndepths} depths and {self.ntimes} time points"
-        )
-
-        # We store an index which maps a linear index of valid measures
-        # (i.e. excluding points where there is no phytoplankton record)
-        # to the byte offset in the data file
-        self._index_file = None
-        self._index_lock = Lock()
-        self._load_index(filepath.name, overwrite_index)
-
-        # The subset map is a filehandler to a file
-        # where the i-th row contains the i-th index in the original dataset
-        self._subset_map = None
-        if subset_file is not None:
-            self._subset_lock = Lock()
-            self._load_subset_index(subset_file)
-
-        self.transform = transform
-        self.target_transform = target_transform
-
-    def _generate_index(self, indexpath):
-        """
-        Generate an index for the binary data file
-        which maps the linear index to the original data file offset
-        and 4D indices of valid measures
-        Note we only index valid measures, i.e. excluding coast, etc..
-
-        Generating the index takes around 2 minutes for 50M samples
-        """
-
-        fp = open(self.filepath, "rb")
-        with open(indexpath, "wb") as fhindex:
-            linear_index = 0
-            # Rewind the file handler to just after the header
-            fp.seek(self.header_offset, os.SEEK_SET)
-            for ilatitude in tqdm.tqdm(range(self.nlatitudes)):
-                for ilongitude in range(self.nlongitudes):
-                    for idepth in range(self.ndepths):
-                        # No need to iterate over time
-                        # only latitude x longitude x depth
-                        # specify a valid or invalid location
-
-                        t0_offset = (
-                            self.header_offset
-                            + (
-                                ilatitude * self.lat_chunk_size
-                                + ilongitude * self.lon_chunk_size
-                                + idepth * self.depth_chunk_size
-                            )
-                            * self.row_size
-                        )
-                        fp.seek(t0_offset, os.SEEK_SET)
-                        # Just read the 'valid' field
-                        (is_valid,) = struct.unpack(
-                            "<?", fp.read(struct.calcsize("<?"))
-                        )
-
-                        # If the location is valid
-                        # we record all the time samples
-                        if is_valid:
-                            for dt in range(
-                                0, self.ntimes - self.num_days + 1, self.num_days
-                            ):
-                                fileoffset = t0_offset + dt * self.row_size
-
-                                fmt = (
-                                    _ENCODING_ENDIAN
-                                    + (_ENCODING_LINEAR * 2)
-                                    + (_ENCODING_INDEX * 4)
-                                )
-                                write_bin_data(
-                                    fhindex,
-                                    fmt,
-                                    (
-                                        linear_index,
-                                        fileoffset,
-                                        ilatitude,
-                                        ilongitude,
-                                        idepth,
-                                        dt,
-                                    ),
-                                )
-
-                                linear_index += 1
-        fp.close()
-
-    def _load_index(self, basename: str, overwrite_index: bool):
-        """
-        Loads (and possibly compute) an index file
-        to convert a linear index idx to its corresponding
-        time, depth, latitude, longitude indices
-
-        Arguments:
-            basename: the basename of the data file, used to name the index file
-            overwrite_index: if True, the index is regenerated even if it already exists
-        """
-        indexpath = pathlib.Path(".") / f"{basename}_index.idx"
-        # Generate the index if necessary or requested
-        if not indexpath.exists() or overwrite_index:
-            logging.info("Generating the index")
-            self._generate_index(indexpath)
-        # And then load the index
-        logging.info(f"Loading the index from {indexpath}")
-        self._index_file = open(indexpath, "rb")
-
-    def _load_subset_index(self, subsetpath):
-        self._subset_map = open(subsetpath, "rb")
-
-    def _get_fileoffset(self, idx):
-        fmt = _ENCODING_ENDIAN + (_ENCODING_LINEAR * 2) + (_ENCODING_INDEX * 4)
-        whence = 0 if idx >= 0 else 2
-        offset = idx * struct.calcsize(fmt)
-
-        (values, _) = read_bin_data(
-            self._index_file, offset, whence, fmt, self._index_lock
-        )
-        linidx = values[0]
-        file_offset = values[1]
-        tab_indices = values[2:]
-        return linidx, file_offset, tab_indices
-
-    def __getitem__(self, idx):
-        """
-        Access the i-th measure
-        Return an array of size (15,)+ time(?) + depth(?)+ lat(?) + lon(?) and the phyc to predict
-        """
-
-        # If we are processing a subset, convert the "idx" to the
-        # original dataset index
-        if self._subset_map is not None:
-            fmt = _ENCODING_ENDIAN + _ENCODING_LINEAR
-            offset = idx * struct.calcsize(fmt)
-            (values, _) = read_bin_data(
-                self._subset_map, offset, os.SEEK_SET, fmt, self._subset_lock
-            )
-            idx = values[0]
-
-        # File offset for the i-th sample
-        _, file_offset, (ilatitude, ilongitude, idepth, itime) = self._get_fileoffset(
-            idx
-        )
-        values, _ = read_bin_data(
-            self.fp,
-            file_offset,
-            os.SEEK_SET,
-            _ENCODING_ENDIAN + (self.row_format * self.num_days),
-            self.fp_lock,
-        )
-
-        values = np.array(values).reshape((self.num_days, -1))
-        assert np.all(
-            values[:, 0]
-        )  # This is always expected to be True since only valid
-        # locations can be indexed with our generated index map
-
-        # Prepend the latitude, longitude, depth and time
-        # TODO Take the time to think about whether this is relevant !
-        coordinates = np.ones((self.num_days, 4))
-        # print(coordinates.shape, self.times[itime : (itime + self.num_days)])
-        coordinates[:, 0] *= self.latitudes[ilatitude]
-        coordinates[:, 1] *= self.longitudes[ilongitude]
-        coordinates[:, 2] *= self.depths[idepth]
-        coordinates[:, 3] = self.times[itime : (itime + self.num_days)]
-
-        # Stick the values with the coordinates
-        # This np array is now (num_days , num_coordinates + num_values)
-        values = np.hstack([coordinates, values[:, 1:]])
-
-        if self.train:
-            in_values = torch.Tensor(values[:, :-1])
-            out_values = torch.Tensor(values[:, -1])
-
-            if self.transform is not None:
-                in_values = self.transform(in_values)
-
-            if self.target_transform is not None:
-                out_values = self.target_transform(out_values)
-
-            return in_values, out_values
-        else:
-            in_values = torch.Tensor(values[:, :-1])
-            if self.transform is not None:
-                in_values = self.transform(in_values)
-            return in_values
-
-    def __len__(self):
-        """
-        Returns the total number of valid datapoints for this dataset
-
-        This corresponds to the length of the index or subset index
-        """
-        # If we are processing a subset, the size of the dataset
-        # is the size of the subset
-        if self._subset_map is not None:
-            subset_size = os.path.getsize(self._subset_map.name)
-            return subset_size // struct.calcsize(_ENCODING_ENDIAN + _ENCODING_LINEAR)
-        else:
-            # Access the last entry of the index file
-            # and get its linear index. This linear index is also
-            # the total number of "rows" minus 1 in the file
-            lin_idx, _, _ = self._get_fileoffset(-1)
-            return lin_idx + 1
-
-
-def train_valid_split(
-    dataset,
-    valid_ratio,
-    train_subset_filepath,
-    valid_subset_filepath,
-    max_num_samples=None,
-):
-    """
-    Generates two splits of fractions 1-valid_ratio and valid_ratio
-    Takes about 2 minutes to generate the splits for 50M datapoints
-
-    For small sets, we could use indices in main memory with
-    torch.utils.data.Subset
-    """
-
-    N = len(dataset)
-    if max_num_samples is not None:
-        pkeep = max_num_samples / N
-        Nvalid = int(valid_ratio * max_num_samples)
-        Ntrain = max_num_samples - Nvalid
-        ptrain = Ntrain / max_num_samples
-
-    else:
-        pkeep = 1.0
-        Nvalid = int(valid_ratio * N)
-        Ntrain = N - Nvalid
-        ptrain = Ntrain / N
-
-    ftrain = open(train_subset_filepath, "wb")
-    fvalid = open(valid_subset_filepath, "wb")
-
-    gen = np.random.default_rng()
-
-    logging.info(f"Generating the subset files from {N} samples")
-    for i in tqdm.tqdm(range(N)):
-        if gen.uniform() < pkeep:
-            if gen.uniform() < ptrain:
-                ftrain.write(struct.pack(_ENCODING_ENDIAN + _ENCODING_LINEAR, i))
-            else:
-                fvalid.write(struct.pack(_ENCODING_ENDIAN + _ENCODING_LINEAR, i))
-    fvalid.close()
-    ftrain.close()
-
-
-def get_dataloaders(
-    filepath,
-    num_days,
-    batch_size,
-    num_workers,
-    pin_memory,
-    valid_ratio,
-    overwrite_index=True,
-    train_transform=None,
-    train_target_transform=None,
-    valid_transform=None,
-    valid_target_transform=None,
-    max_num_samples=None,
-):
-    logging.info("= Dataloaders")
-    # Load the base dataset
-    logging.info("  - Dataset creation")
-    dataset = Dataset(
-        filepath, train=True, overwrite_index=overwrite_index, num_days=num_days
-    )
-    logging.info(f"  - Loaded a dataset with {len(dataset)} samples")
-
-    # Split the number of samples in each fold
-    logging.info("  - Splitting the data in training and validation sets")
-    train_subset_file = "train_indices.subset"
-    valid_subset_file = "valid_indices.subset"
-    train_valid_split(
-        dataset, valid_ratio, train_subset_file, valid_subset_file, max_num_samples
-    )
-
-    logging.info("  - Subset dataset")
-    train_dataset = Dataset(
-        filepath,
-        subset_file=train_subset_file,
-        transform=train_transform,
-        target_transform=train_target_transform,
-        num_days=num_days,
-    )
-    valid_dataset = Dataset(
-        filepath,
-        subset_file=valid_subset_file,
-        transform=valid_transform,
-        target_transform=valid_target_transform,
-        num_days=num_days,
-    )
-    # The sum of the sizes of the two folds is not expected to be
-    # exactly max_num_samples
-    logging.info(f"  - The train fold has {len(train_dataset)} samples")
-    logging.info(f"  - The valid fold has {len(valid_dataset)} samples")
-
-    # Build the dataloaders
-    train_loader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=num_workers,
-        pin_memory=pin_memory,
-    )
-
-    valid_loader = torch.utils.data.DataLoader(
-        valid_dataset,
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=num_workers,
-        pin_memory=pin_memory,
-    )
-
-    return train_loader, valid_loader
-
-
-def get_test_dataloader(
-    filepath,
-    num_days,
-    batch_size,
-    num_workers,
-    pin_memory,
-    overwrite_index=True,
-    transform=None,
-    target_transform=None,
-):
-    logging.info("= Dataloaders")
-    # Load the base dataset
-    logging.info("  - Dataset creation")
-    test_dataset = Dataset(
-        filepath,
-        train=False,
-        transform=transform,
-        target_transform=target_transform,
-        num_days=num_days,
-        overwrite_index=overwrite_index,
-    )
-    logging.info(f"I loaded {len(test_dataset)} values in the test set")
-
-    # Build the dataloader, being careful not to shuffle the data
-    # so that the original sample order is preserved for the submission
-    test_loader = torch.utils.data.DataLoader(
-        test_dataset,
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=num_workers,
-        pin_memory=pin_memory,
-    )
-    return test_loader
-
-
-def test_point_dataset():
-    dataset = Dataset(_DEFAULT_TRAIN_FILEPATH, train=True, overwrite_index=True)
-    logging.info(f"Dataset loaded with {len(dataset)} samples")
-
-    xi, yi = dataset[103]
-    print(f"Loaded one sample with shapes of xi({xi.shape}), yi({yi.shape})")
-
-    # Test iterating over the whole dataset
-    # This takes around 1 hour to iterate over the 50M samples
-    # logging.info("Iterating over the whole dataset, sample by sample")
-    # for i in tqdm.tqdm(range(len(dataset))):
-    #     xi, yi = dataset[i]
-
-
-def test_getitem():
-    dataset = Dataset(_DEFAULT_TRAIN_FILEPATH, train=True, overwrite_index=True)
-    logging.info(f"Dataset loaded with {len(dataset)} samples")
-    idx = 113728
-
-    lin_index, file_offset, tab_indices = dataset._get_fileoffset(idx)
-    logging.info(
-        f"The idx {idx} corresponds to : \n\tlinear index={lin_index}\n\tfile offset={file_offset}\n\ttab indices={tab_indices}"
-    )
-
-def test_dataloader():
-    logging.info("====> Test dataloader")
-    use_cuda = torch.cuda.is_available()
-    trainpath = _DEFAULT_TRAIN_FILEPATH
-    num_days = 1  # Test with sequence of 1 day
-    batch_size = 128
-    num_workers = 7
-    valid_ratio = 0.2
-    # max_num_samples = 1000
-    max_num_samples = None
-
-    train_loader, valid_loader = get_dataloaders(
-        trainpath,
-        num_days,
-        batch_size,
-        num_workers,
-        use_cuda,
-        valid_ratio,
-        overwrite_index=True,
-        max_num_samples=max_num_samples,
-    )
-
-    it = iter(train_loader)
-    X, Y = next(it)
-    logging.info(f"Got a minibatch of size {X.shape} -> {Y.shape}")
-
-    # This takes around 20 minutes for 50M samples
-    minis = []
-    maxis = []
-    logging.info("Iterating over the whole data loader")
-    for X, Y in tqdm.tqdm(train_loader):
-        minis.append(Y.min())
-        maxis.append(Y.max())
-    logging.info(f"phyc ranges in [{min(minis)},{max(maxis)}]")
-
-
-def test_time_dataset():
-    import matplotlib.pyplot as plt
-
-    use_cuda = torch.cuda.is_available()
-    trainpath = _DEFAULT_TRAIN_FILEPATH
-    num_days = 35
-
-    batch_size = 24
-    num_workers = 7
-    valid_ratio = 0.2
-    max_num_samples = None
-
-    train_dataset = Dataset(
-        trainpath,
-        overwrite_index=True,
-        train=True,
-        num_days=num_days,
-    )
-    logging.info(f"The training set contains {len(train_dataset)} samples")
-
-    # Test indexing one element
-    X, y = train_dataset[166]
-    logging.info(f"Got one sample of shape {X.shape}, {y.shape}")
-
-    # Iterate over the whole training set, around 2 minutes
-    # logging.info("Iterating over the whole dataset, sample by sample")
-    # for i in tqdm.tqdm(range(len(train_dataset))):
-    #     X, y = train_dataset[i]
-
-    # Create the dataloader for getting minibatches
-    train_loader, valid_loader = get_dataloaders(
-        trainpath,
-        num_days,
-        batch_size,
-        num_workers,
-        use_cuda,
-        valid_ratio,
-        overwrite_index=True,
-        max_num_samples=max_num_samples,
-    )
-
-    it = iter(train_loader)
-    X, Y = next(it)
-    logging.info(f"Got one minibatch of shape {X.shape}, {Y.shape}")
-
-    plt.figure()
-    plt.plot(Y.numpy().T)
-    plt.show()
-
-    logging.info(f"Got a minibatch of size {X.shape} -> {Y.shape}")
-    logging.info(
-        "The tensors are in order (B, T, N) for the input and (B, T) for the output. Be carefull when using convolutional layers where 1D convolutions expects (B, N, T). Be carefull when using recurrent layers which are by default Time first"
-    )
-    logging.info("Iterating over the whole train loader for testing the time it takes")
-    # Around 1 minute
-    for X, Y in tqdm.tqdm(train_loader):
-        continue
-
-
-def test_time_test_dataset():
-    import matplotlib.pyplot as plt
-
-    use_cuda = torch.cuda.is_available()
-    testpath = _DEFAULT_TEST_FILEPATH
-    num_days = 365
-
-    batch_size = 24
-    num_workers = 7
-
-    test_dataset = Dataset(
-        testpath,
-        overwrite_index=True,
-        train=False,
-        num_days=num_days,
-    )
-    logging.info(f"The test set contains {len(test_dataset)} samples")
-
-    # Test indexing one element
-    X = test_dataset[166]
-    logging.info(f"Got one sample of shape {X.shape}")
-
-    # Iterate over the whole training set, around 2 minutes
-    # logging.info("Iterating over the whole dataset, sample by sample")
-    # for i in tqdm.tqdm(range(len(test_dataset))):
-    #     X = test_dataset[i]
-
-    # Create the dataloader for getting minibatches
-    test_loader = get_test_dataloader(
-        testpath,
-        num_days,
-        batch_size,
-        num_workers,
-        use_cuda,
-        overwrite_index=True,
-    )
-
-    it = iter(test_loader)
-    X = next(it)
-    logging.info(f"Got one minibatch of shape {X.shape}")
-
-    plt.figure()
-    plt.plot(X[:, :, 6].T)
-    plt.title("NH4")
-    plt.show()
-
-    logging.info(f"Got a minibatch of size {X.shape}")
-    logging.info(
-        "The tensors are in order (B, T, N). Be carefull when using convolutional layers where 1D convolutions expects (B, N, T). Be carefull when using recurrent layers which are by default Time first"
-    )
-    logging.info("Iterating over the whole train loader for testing the time it takes")
-    # Around 1 minute
-    for X in tqdm.tqdm(test_loader):
-        continue
-
-
-def test_transforms():
-    import matplotlib.pyplot as plt
-
-    logging.info("====> Test transforms dataloader")
-    use_cuda = torch.cuda.is_available()
-    trainpath = _DEFAULT_TRAIN_FILEPATH
-    num_days = 834  # Test with long sequences of 834 days
-    batch_size = 128
-    num_workers = 7
-    valid_ratio = 0.2
-    max_num_samples = 1000
-
-    def train_transform(X):
-        """
-        Transform to be applied to an input sample X of shape (T, N)
-        """
-        # The variables in X are
-        #     latitude, longitude, depth, time,
-        # followed by
-        # dissic, mlotst, nh4, no3, nppv, o2, ph, po4, so, talk, thetao, uo
-        # vo, zos
-
-        # As a matter of example, we drop some input variables
-        # For example, we drop the latitude and longitude
-        return X[:, 2:]
-
-    def valid_transform(X):
-        return X
-
-    def target_transform(Y):
-        return Y.log()
-
-    train_loader, _ = get_dataloaders(
-        trainpath,
-        num_days,
-        batch_size,
-        num_workers,
-        use_cuda,
-        valid_ratio,
-        overwrite_index=True,
-        max_num_samples=max_num_samples,
-        train_target_transform=target_transform,
-        train_transform=train_transform,
-        valid_target_transform=target_transform,
-        valid_transform=valid_transform,
-    )
-
-    it = iter(train_loader)
-    X, Y = next(it)
-    logging.info(f"Got one minibatch of shape {X.shape}, {Y.shape}")
-    logging.info(
-        "The tensors are in order (B, T, N) for the input and (B, T) for the output. Be carefull when using convolutional layers where 1D convolutions expects (B, N, T). Be carefull when using recurrent layers which are by default Time first"
-    )
-
-    plt.figure()
-
-    plt.plot(list(map(datetime.fromtimestamp, X[0, :, 1].tolist())), Y[0, :])
-    plt.xticks(rotation=70)
-    plt.tight_layout()
-    plt.show()
-
-
-if __name__ == "__main__":
-    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(message)s")
-
-    logging.info("\n" + "=" * 80)
-    test_point_dataset()
-
-    logging.info("\n" + "=" * 80)
-    test_getitem()
-
-    logging.info("\n" + "=" * 80)
-    test_dataloader()
-
-    logging.info("\n" + "=" * 80)
-    test_time_dataset()
-
-    logging.info("\n" + "=" * 80)
-    test_time_test_dataset()
-
-    logging.info("\n" + "=" * 80)
-    test_transforms()
diff --git a/config.yml b/config.yml
index 1120a1b09bc4ea0dac7c292003e3b4ca204b81df..783ce55dd53d9d48f5d4781e62a6feb61e4aa60e 100644
--- a/config.yml
+++ b/config.yml
@@ -1,6 +1,6 @@
 # Dataset Configuration
 Dataset:
-  num_days: 73 # Test with sequence of 1 day - should be the same as in Test -
+  num_days: 73 # Number of days in each sequence - should be the same as in Test -
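+  # Note: num_days should be a divisor of 365 (1, 5, 73 or 365) so the test year splits into whole sequences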
   batch_size: 64
   num_workers: 7
   valid_ratio: 0.2
@@ -48,7 +48,7 @@ Training:
 #Model selection
 Model:
   Name: BidirectionalLSTM
-  #choose in {LinearRegression, BidirectionalLSTM, RNN}
+  #choose in {LinearRegression, BidirectionalLSTM, RNN, CNN1D}
 
 #Model parameters selection
 LinearRegression:
diff --git a/create_submission.py b/create_submission.py
index fba65a17abf6b0441d34249acc73c9399898e985..0ce8ac788d56f642971666fc71adfa3df6feb0eb 100644
--- a/create_submission.py
+++ b/create_submission.py
@@ -21,6 +21,7 @@ import tqdm
 import torch
 import torch.nn as nn
 import argparse
+import yaml
 
 # Local imports
 import dataloader
@@ -31,22 +32,24 @@ def dummy_model(X):
     # Divided by a magic number
     return X[:, :, 4:].mean(dim=2) / 26  # This is (B, T)
 
-def create_submission(model, transform, device, rootDir, logdir):
+def create_submission(args, model, transform, device, rootDir, logdir):
+    with open("config.yml") as config_file:
+        cfg = yaml.load(config_file, Loader=yaml.FullLoader)
     step_days = 10
     batch_size = 1024
     # We make chunks of num_days consecutive samples; As our dummy predictor
     # is not using the temporal context, this is here arbitrarily chosen
     # However, note that it must be a divisor of the total number of days
     # in the 2017 year , either 1, 5, 73 or 365
-    num_days = 73
+    num_days = cfg["Dataset"]["num_days"]
     num_workers = 7
 
     use_cuda = torch.cuda.is_available()
     # Build the dataloaders
     logging.info("Building the dataloader")
 
-    test_loader = dataloader.get_test_dataloader(
-        dataloader._DEFAULT_TEST_FILEPATH,
+    if args.PATHTOTESTSET != None:
+        test_loader = dataloader.get_test_dataloader(
+        args.PATHTOTESTSET,
         num_days,
         batch_size,
         num_workers,
@@ -55,6 +58,17 @@ def create_submission(model, transform, device, rootDir, logdir):
         transform=transform,
         target_transform=None,
     )
+    else :
+        test_loader = dataloader.get_test_dataloader(
+            dataloader._DEFAULT_TEST_FILEPATH,
+            num_days,
+            batch_size,
+            num_workers,
+            use_cuda,
+            overwrite_index=True,
+            transform=transform,
+            target_transform=None,
+        )
     num_days_test = test_loader.dataset.ntimes
 
     logging.info("= Filling in the submission file")
diff --git a/debug.py b/debug.py
deleted file mode 100644
index 0ecb5f4a954ac10770479ea6d83e402229bacf66..0000000000000000000000000000000000000000
--- a/debug.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from dataset import Dataset
-import bindataset
-
-_DEFAULT_TRAIN_FILEPATH = "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin"
-_DEFAULT_TEST_FILEPATH = (
-    "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2017-testing.nc.bin"
-)
-
-idx ="sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx"
-
-data = Dataset(_DEFAULT_TRAIN_FILEPATH, overwrite_index = False, train = False, subset_file = idx, num_days = 20, transform = None, target_transform = None)
-
-
-"""
-Builds a pointdataset generating the index if necessary or requested
-
-Arguments:
-    filepath: the full path to the nc file to load
-    overwrite_index: if True ignores the index and regenerates it
-    train: if True, accessing an element also returns the phyc
-    subset_file: a filename which holds a list of indices this dataset must use
-    num_days: the number of days that each sample considers
-    transform: a transform to apply to the input tensor
-    target_transform: a transform to apply to the phyc output tensor
-"""
-
-
-print("Len whole dataset :") 
-print(len(data))
-print()
-
-print("Shape data[0] : ")
-print(data[0].shape)
-
-print(data.in_variables)
-print(len(data.in_variables))
-
-
-bindataset.test_time_dataset()
\ No newline at end of file
diff --git a/logs/main_unit_test.log b/logs/main_unit_test.log
index 950ba21aa3f1c26e95b34f1b37fcf47d46e3eb8e..38daef58f7cd0d49949d994431bc7b5dea593041 100644
--- a/logs/main_unit_test.log
+++ b/logs/main_unit_test.log
@@ -2557,18 +2557,19 @@ INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and
 INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
 INFO:root:  - The train fold has 541712 samples
 INFO:root:  - The valid fold has 135448 samples
+INFO:root:Building the dataloader
 INFO:root:= Dataloaders
 INFO:root:  - Dataset creation
-INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 2222 time points
+INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 365 time points
 INFO:root:Generating the index
-INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
-INFO:root:  - Loaded a dataset with 677160 samples
-INFO:root:  - Splitting the data in training and validation sets
-INFO:root:Generating the subset files from 677160 samples
-INFO:root:  - Subset dataset
-INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 2222 time points
-INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
-INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 2222 time points
-INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
-INFO:root:  - The train fold has 541661 samples
-INFO:root:  - The valid fold has 135499 samples
+INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2017-testing.nc.bin_index.idx
+INFO:root:I loaded 112860 values in the test set
+INFO:root:= Filling in the submission file
+INFO:root:Building the dataloader
+INFO:root:= Dataloaders
+INFO:root:  - Dataset creation
+INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 365 time points
+INFO:root:Generating the index
+INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2017-testing.nc.bin_index.idx
+INFO:root:I loaded 112860 values in the test set
+INFO:root:= Filling in the submission file
diff --git a/losses.py b/losses.py
index 435d124b376feeb8d80246378b41e27582340f54..281976d7b54980269fc187addf43b62e9fbc4c54 100644
--- a/losses.py
+++ b/losses.py
@@ -1,11 +1,2 @@
 import torch.nn as nn
-import torch
-
-class RMSLELoss(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.mse = nn.MSELoss()
-        
-    def forward(self, pred, actual):
-        resized_actual =  actual.view([actual.shape[0], actual.shape[1],1])
-        return torch.sqrt(self.mse(torch.log(torch.add(pred,1)), torch.log(torch.add(resized_actual, 1))))
\ No newline at end of file
+import torch
\ No newline at end of file
diff --git a/losses/RMSLE.py b/losses/RMSLE.py
deleted file mode 100644
index cf5cff69c2f0421cdce0b80bf171d13e6b57489e..0000000000000000000000000000000000000000
--- a/losses/RMSLE.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import torch.nn as nn
-
-class RMSLELoss(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.mse = nn.MSELoss()
-        
-    def forward(self, pred, actual):
-        return torch.sqrt(self.mse(torch.log(pred + 1), torch.log(actual + 1)))
\ No newline at end of file
diff --git a/main.py b/main.py
index bab93c2e2ba3db40fb129ce9e77836f68fdf6be6..76645eb3625939e1d841fe78ea9f6dd200eb754d 100644
--- a/main.py
+++ b/main.py
@@ -2,9 +2,7 @@
 import dataloader
 import model
 import test
-from train import train
-import losses
-import optimizers
+import my_train
 import create_submission
 import utils
 
@@ -22,33 +20,14 @@ def optimizer(cfg, network):
     result = {"Adam" : torch.optim.Adam(network.parameters())}
     return result[cfg["Optimizer"]]
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-    "--no_wandb",
-    action="store_true",
-    help="If specified, no log will be sent to wandb. Especially useful when running batch jobs.",
-    )
-
-    parser.add_argument(
-    "--rootDir",
-    default=None,
-    help="Directory in which the log files will be stored"
-    )
-
-    args = parser.parse_args()
-
-    config_file = open("config.yml")
-    cfg = yaml.load(config_file, Loader=yaml.FullLoader)
-
+def train(args, cfg):
     rootDir = args.rootDir if args.rootDir != None else cfg["LogDir"]
 
     logging.basicConfig(filename= rootDir + 'main_unit_test.log', level=logging.INFO)
     
 
     use_cuda = torch.cuda.is_available()
-    trainpath           = cfg["Dataset"]["_DEFAULT_TRAIN_FILEPATH"]
+    trainpath           = args.PATHTOTRAININGSET if args.PATHTOTRAININGSET != None else cfg["Dataset"]["_DEFAULT_TRAIN_FILEPATH"]
     num_days            = int(cfg["Dataset"]["num_days"])
     batch_size          = int(cfg["Dataset"]["batch_size"])
     num_workers         = int(cfg["Dataset"]["num_workers"])
@@ -110,7 +89,7 @@ if __name__ == "__main__":
 
     model.initialize_model(cfg, network)
 
-    f_loss = losses.RMSLELoss()
+    f_loss = model.RMSLELoss()
 
     optimizer = optimizer(cfg, network)
 
@@ -129,22 +108,105 @@ if __name__ == "__main__":
         wandb.run.name = raw_run_name
         wandb.watch(network, log_freq = log_freq)
 
-    #torch.autograd.set_detect_anomaly(True)
+    if args.detect_anomaly:
+        torch.autograd.set_detect_anomaly(True)
 
+    best_val_loss = None
     for t in range(cfg["Training"]["Epochs"]):
         print("Epoch {}".format(t))
-        train(args, network, train_loader, f_loss, optimizer, device, log_interval)
+        my_train.train(args, network, train_loader, f_loss, optimizer, device, log_interval)
 
         val_loss = test.test(network, valid_loader, f_loss, device)
 
+        if best_val_loss is None or val_loss < best_val_loss:
+            best_val_loss = val_loss
+            network_checkpoint.update(val_loss)
+
         scheduler.step(val_loss)
 
-        network_checkpoint.update(val_loss)
 
         print(" Validation : Loss : {:.4f}".format(val_loss))
         if not args.no_wandb:
             wandb.log({"val_loss": val_loss})
 
     utils.write_summary(logdir, network, optimizer, val_loss)
+
+    logging.info(f"Best model saved in folder {logdir}")
+
+
+# Named test_model to avoid shadowing the imported test module used in train()
+def test_model(args):
+
+    dataset_transform = cfg["Dataset"]["Transform"]
+    rootDir = args.rootDir if args.rootDir != None else cfg["LogDir"]
+
+    use_cuda = torch.cuda.is_available()
+    if use_cuda :
+        device = torch.device('cuda')
+    else :
+        device = torch.device('cpu')
+
+    logdir, raw_run_name = utils.create_unique_logpath(rootDir, cfg["Model"]["Name"])
+
+    model_path = args.PATHTOCHECKPOINT
+
+    network = model.build_model(cfg, 14)
+
+    network = network.to(device)
+
+    network.load_state_dict(torch.load(model_path, map_location=device))
+
+    create_submission.create_submission(args, network, eval(dataset_transform), device, rootDir, logdir)
+
+    logging.info(f"The submission csv file has been created in the folder : {logdir}")
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+    "--no_wandb",
+    action="store_true",
+    help="If specified, no log will be sent to wandb. Especially useful when running batch jobs.",
+    )
+
+    parser.add_argument(
+    "--detect_anomaly",
+    action="store_true",
+    help="If specified, torch.autograd.set_detect_anomaly(True) will be activated",
+    )
+
+    parser.add_argument(
+    "--rootDir",
+    default=None,
+    help="Directory in which the log files will be stored"
+    )
+
+    parser.add_argument(
+        "--PATHTOTESTSET",
+    default=None,
+    help="Path of the file on which the model will be tested on"
+    )
+
+    parser.add_argument(
+        "--PATHTOTRAININGSET",
+        default=None,
+        help="Path of the file on which the model will be trained on"
+    )
+
+    parser.add_argument(
+        "--PATHTOCHECKPOINT",
+        default="./logs/BestBidirectionalLSTM/best_model.pt",
+        help="Path of the model to load"
+    )
+
+    parser.add_argument(
+        "command", 
+        choices=["train", "test"]
+    )
+
+    args = parser.parse_args()
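+
+    # Example invocations (illustrative):
+    #   python main.py train --rootDir ./logs/
+    #   python main.py test --PATHTOCHECKPOINT ./logs/BestBidirectionalLSTM/best_model.pt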
     
-    create_submission.create_submission(network, eval(dataset_transform), device, rootDir, logdir)
+    config_file = open("config.yml")
+    cfg = yaml.load(config_file, Loader=yaml.FullLoader)
+    
+    eval(f"{args.command}(args)")
+
diff --git a/model.py b/model.py
index e9e91efe40c647580d499f8b8ff3a5574e8412e2..2fd3b44e67a4f06f284337c37b15f3caace918c2 100644
--- a/model.py
+++ b/model.py
@@ -167,7 +167,7 @@ def init_xavier(module):
 # Generic function to build model
 
 def build_model(cfg, input_size):    
-    print(f"{cfg['Model']['Name']}(cfg, input_size)")
+    print(f"The model used is : {cfg['Model']['Name']}(cfg, input_size)")
     return eval(f"{cfg['Model']['Name']}(cfg, input_size)")
 
 def initialize_model(cfg, network):
@@ -176,6 +176,17 @@ def initialize_model(cfg, network):
             #print(f"{cfg[cfg['Model']['Name']]['Initialization']}")
             layer.apply(eval(f"{cfg[cfg['Model']['Name']]['Initialization']}"))
 
+# Loss function
+class RMSLELoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.mse = nn.MSELoss()
+        
+    def forward(self, pred, actual):
+        resized_actual =  actual.view([actual.shape[0], actual.shape[1],1])
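+        # i.e. RMSLE = sqrt(MSE(log(pred + 1), log(actual + 1))), with the target
+        # reshaped from (B, T) to (B, T, 1) to match the prediction shape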
+        return torch.sqrt(self.mse(torch.log(torch.add(pred,1)), torch.log(torch.add(resized_actual, 1))))
+
+# Saving the best model
 class ModelCheckpoint:
     def __init__(self, filepath, model):
         self.min_loss = None
@@ -192,4 +203,5 @@ if __name__== "__main__":
     import yaml
     config_file = open("config.yml","r")
     cfg = yaml.load(config_file)
-    print(cfg['Model']['Name'])
\ No newline at end of file
+    print(cfg['Model']['Name'])
+
diff --git a/optimizers.py b/optimizers.py
deleted file mode 100644
index a696044fba31703f1f29f81818419f944f86edcf..0000000000000000000000000000000000000000
--- a/optimizers.py
+++ /dev/null
@@ -1,5 +0,0 @@
-import torch.optim
-
-def optimizer(cfg, model):
-    result = {"Adam" : torch.optim.Adam(model.parameters(), lr = 1e-3)}
-    return result[cfg["Optimizer"]]
diff --git a/train.py b/train.py
deleted file mode 100644
index 778b6c7e356fa2152a9edc82e92b40946d401639..0000000000000000000000000000000000000000
--- a/train.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from tqdm import tqdm
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-import wandb
-
-def train(args, model, loader, f_loss, optimizer, device, log_interval = 100):
-    """
-    Train a model for one epoch, iterating over the loader
-    using the f_loss to compute the loss and the optimizer
-    to update the parameters of the model.
-
-    Arguments :
-
-        model     -- A torch.nn.Module object
-        loader    -- A torch.utils.data.DataLoader
-        f_loss    -- The loss function, i.e. a loss Module
-        optimizer -- A torch.optim.Optimizer object
-        device    -- a torch.device class specifying the device
-                     used for computation
-
-    Returns :
-    """
-
-    model.train()
-    gradients = []
-    out         = []
-    tar         = [] 
-    for batch_idx, (inputs, targets) in tqdm(enumerate(loader), total = len(loader)):
-        inputs, targets = inputs.to(device), targets.to(device)
-
-        # Compute the forward pass through the network up to the loss
-
-        # target's shape is (B, Num_days, 1)
-        outputs = model(inputs)
-        loss = f_loss(outputs, targets)
-
-        # Backward and optimize
-        optimizer.zero_grad()
-        loss.backward()
-
-        #torch.nn.utils.clip_grad_norm(model.parameters(), 50)
-        
-        Y = list(model.parameters())[0].grad.cpu().tolist()
-        
-
-        if not args.no_wandb:
-            if batch_idx % log_interval == 0:
-                wandb.log({"train_loss" : loss})
-        optimizer.step()
-
-def visualize_gradients(gradients):
-    print(gradients)
-    import numpy as np
-    X = np.linspace(0,len(gradients),len(gradients))
-    plt.scatter(X,gradients)
-    plt.show()
-
-if __name__=="__main__":
-    import numpy as np
-    Y = [[1,2,3],[2,4,8],[2,5,6], [8,9,10]]
-    X = np.linspace(0,len(Y),len(Y))
-    for i,curve in enumerate(Y):
-        for point in curve : 
-            plt.scatter(X[i],point)
-    plt.show()
\ No newline at end of file