# Standard imports
import os
import sys
import logging
import pathlib
from typing import Union
import struct

from datetime import datetime

# External imports
import numpy as np
import tqdm
import torch
import torch.utils.data as data

# Imports from the project
from dataset import Dataset

_DEFAULT_TRAIN_FILEPATH = "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin"
_DEFAULT_TEST_FILEPATH = (
    "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2017-testing.nc.bin"
)

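# struct format codes used to encode sample indices in the subset files:
# "<" stands for little-endian, "I" for an unsigned 32-bit integer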
_ENCODING_LINEAR = "I"
_ENCODING_INDEX = "I"  # "h" (a 2-byte short) would be sufficient
_ENCODING_OFFSET_FORMAT = ""
_ENCODING_ENDIAN = "<"

def train_valid_split(
    dataset,
    valid_ratio,
    train_subset_filepath,
    valid_subset_filepath,
    max_num_samples=None,
):
    """
    Generates two splits of 1-valid_ratio and valid_ratio fractions
    Takes 2 minutes for generating the splits for 50M datapoints

    For smalls sets, we could use indices in main memory with
    torch.utils.data.Subset
    """

    N = len(dataset)
    if max_num_samples is not None:
        pkeep = max_num_samples / N
        Nvalid = int(valid_ratio * max_num_samples)
        Ntrain = max_num_samples - Nvalid
        ptrain = Ntrain / max_num_samples

    else:
        pkeep = 1.0
        Nvalid = int(valid_ratio * N)
        Ntrain = N - Nvalid
        ptrain = Ntrain / N

    gen = np.random.default_rng()

    logging.info(f"Generating the subset files from {N} samples")
    with open(train_subset_filepath, "wb") as ftrain, open(
        valid_subset_filepath, "wb"
    ) as fvalid:
        for i in tqdm.tqdm(range(N)):
            if gen.uniform() < pkeep:
                if gen.uniform() < ptrain:
                    ftrain.write(struct.pack(_ENCODING_ENDIAN + _ENCODING_LINEAR, i))
                else:
                    fvalid.write(struct.pack(_ENCODING_ENDIAN + _ENCODING_LINEAR, i))

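# Illustrative sketch (not used by the pipeline): the subset files written by
# train_valid_split can be decoded back into sample indices with the same
# struct format, e.g.
#
#     with open("train_indices.subset", "rb") as fh:
#         indices = [
#             i
#             for (i,) in struct.iter_unpack(
#                 _ENCODING_ENDIAN + _ENCODING_LINEAR, fh.read()
#             )
#         ]
#
# The Dataset class presumably performs an equivalent decoding when it is
# given a subset_file argument.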

def get_dataloaders(
    filepath,
    num_days,
    batch_size,
    num_workers,
    pin_memory,
    valid_ratio,
    overwrite_index=True,
    train_transform=None,
    train_target_transform=None,
    valid_transform=None,
    valid_target_transform=None,
    max_num_samples=None,
):
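    """
    Builds the training and validation dataloaders.

    The base training dataset is split into train/valid subset files with
    train_valid_split; two Dataset objects built from these subset files are
    then wrapped into DataLoaders (the training loader is shuffled, the
    validation one is not).

    Returns (train_loader, valid_loader).
    """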
    logging.info("= Dataloaders")
    # Load the base dataset
    logging.info("  - Dataset creation")
    dataset = Dataset(
        filepath, train=True, overwrite_index=overwrite_index, num_days=num_days
    )
    logging.info(f"  - Loaded a dataset with {len(dataset)} samples")

    # Split the number of samples in each fold
    logging.info("  - Splitting the data in training and validation sets")
    train_subset_file = "train_indices.subset"
    valid_subset_file = "valid_indices.subset"
    train_valid_split(
        dataset, valid_ratio, train_subset_file, valid_subset_file, max_num_samples
    )

    logging.info("  - Subset dataset")
    train_dataset = Dataset(
        filepath,
        subset_file=train_subset_file,
        transform=train_transform,
        target_transform=train_target_transform,
        num_days=num_days,
    )
    valid_dataset = Dataset(
        filepath,
        subset_file=valid_subset_file,
        transform=valid_transform,
        target_transform=valid_target_transform,
        num_days=num_days,
    )

    # The sizes of the two folds are not expected to sum up exactly to
    # max_num_samples
    logging.info(f"  - The train fold has {len(train_dataset)} samples")
    logging.info(f"  - The valid fold has {len(valid_dataset)} samples")

    # Build the dataloaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, valid_loader

def get_stats_train_dataset(
    filepath,
    num_days,
    batch_size,
    num_workers,
    pin_memory,
    valid_ratio,
    overwrite_index=True,
    train_transform=None,
    train_target_transform=None,
    valid_transform=None,
    valid_target_transform=None,
    max_num_samples=None,
):
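    """
    Computes per-feature statistics over the training fold: the mean, an
    approximate standard deviation, and the elementwise minimum and maximum.

    Returns (mean, std, maxi, mini), each with one value per input feature.
    """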
    logging.info("= Dataloaders for mean and standard deviation")
    # Load the base dataset
    logging.info("  - Dataset creation")
    dataset = Dataset(
        filepath, train=True, overwrite_index=overwrite_index, num_days=num_days
    )
    logging.info(f"  - Loaded a dataset with {len(dataset)} samples")

    # Split the number of samples in each fold
    logging.info("  - Splitting the data in training and validation sets")
    train_subset_file = "train_indices.subset"
    valid_subset_file = "valid_indices.subset"
    train_valid_split(
        dataset, valid_ratio, train_subset_file, valid_subset_file, max_num_samples
    )

    logging.info("  - Subset dataset")
    train_dataset = Dataset(
        filepath,
        subset_file=train_subset_file,
        transform=train_transform,
        target_transform=train_target_transform,
        num_days=num_days,
    )

    # The size of the train fold is only approximately
    # (1 - valid_ratio) * max_num_samples
    logging.info(f"  - The train fold has {len(train_dataset)} samples")

    # Build the dataloaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

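    # Accumulate the statistics batch by batch: the mean is averaged over the
    # batch and time dimensions and weighted by the batch size; the standard
    # deviation is approximated by the batch-size-weighted average of the
    # per-batch standard deviations; the min and max are initialized from the
    # first batch and then updated elementwise.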
    mean = 0.
    std = 0.
    nb_samples = 0.
    it = iter(train_loader)
    X, Y = next(it)
    maxi = X.max(dim=0).values.max(dim=0).values
    mini = X.min(dim=0).values.min(dim=0).values

    print("Computing the statistics of the train dataset")

    # Iterate over the data
    for X, _ in tqdm.tqdm(train_loader, total=len(train_loader)):
        # Update the running sums
        mean += X.mean(dim=0).mean(dim=0) * X.size(0)
        std += X.std(dim=0).mean(dim=0) * X.size(0)
        # Update mini and maxi
        maxi = torch.stack((X.max(dim=0).values.max(dim=0).values, maxi)).max(dim=0).values
        mini = torch.stack((X.min(dim=0).values.min(dim=0).values, mini)).min(dim=0).values
        # Update the number of samples
        nb_samples += X.size(0)

    # Calculate the mean and std
    mean /= nb_samples
    std /= nb_samples

    return mean, std, maxi, mini

def get_test_dataloader(
    filepath,
    num_days,
    batch_size,
    num_workers,
    pin_memory,
    overwrite_index=True,
    transform=None,
    target_transform=None,
):
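    """
    Builds the dataloader over the test set. The data are not shuffled so
    that the original sample order is preserved.
    """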
    logging.info("= Dataloaders")
    # Load the base dataset
    logging.info("  - Dataset creation")
    test_dataset = Dataset(
        filepath,
        train=False,
        transform=transform,
        target_transform=target_transform,
        num_days=num_days,
        overwrite_index=overwrite_index,
    )
    logging.info(f"I loaded {len(test_dataset)} values in the test set")

    # Build the dataloader; be careful not to shuffle the data
    # so that the original sample order is preserved
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return test_loader

def transform_remove_space_time():
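    """
    Returns a transform that drops the first four columns
    (latitude, longitude, depth, time) and keeps the remaining features.
    """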
    return lambda attributes: attributes[:, 4:]

def transform_normalize_with_train_statistics(MEAN, STD):
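    """
    Returns a transform standardizing the features with the provided
    per-feature MEAN and STD tensors.
    """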
    return lambda attributes: torch.div(torch.sub(attributes, MEAN), STD)

def transform_min_max_scaling(MIN, MAX):
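    """
    Returns a min-max scaling transform using the provided per-feature
    MIN and MAX tensors.
    """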
    return lambda attributes: torch.div(torch.sub(attributes, MIN), torch.sub(MAX, MIN))
    
def composite_transform(f,g):
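    """
    Returns the composition of f and g: a transform applying g, then f.
    """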
    return lambda attributes: f(g(attributes))

if __name__ == "__main__":
    logging.basicConfig(filename='logs/dataloader_unit_test.log', level=logging.INFO)
    logging.info("====> Test dataloader")
    use_cuda = torch.cuda.is_available()
    trainpath = _DEFAULT_TRAIN_FILEPATH
    num_days = 35  # Test with sequences of 35 days
    batch_size = 128
    num_workers = 7
    valid_ratio = 0.2
    # max_num_samples = 1000
    max_num_samples = None

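    # Per-feature statistics over the 18 attributes, presumably precomputed
    # with get_stats_train_dataset on the training data; used by the
    # normalization transform below.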
    MEAN = torch.tensor([
        4.2457e+01, 7.4651e+00, 1.6738e+02, 1.3576e+09, 2.3628e+00, 4.6839e+01,
        2.3855e-01, 3.6535e+00, 1.9776e+00, 2.2628e+02, 8.1003e+00, 1.8691e-01,
        3.8384e+01, 2.6626e+00, 1.4315e+01, -4.1419e-03, 6.0274e-03, -5.1017e-01,
    ])

    STD = torch.tensor([
        5.8939e-01, 8.1625e-01, 1.4535e+02, 5.4952e+07, 1.7543e-02, 1.3846e+02,
        2.1302e-01, 1.9558e+00, 4.1455e+00, 1.2408e+01, 2.2938e-02, 9.9070e-02,
        1.9490e-01, 9.2847e-03, 2.2575e+00, 8.5310e-02, 7.8280e-02, 8.6237e-02,
    ])

    train_loader, valid_loader = get_dataloaders(
        filepath=trainpath,
        num_days=num_days,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        valid_ratio=valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
        train_transform=composite_transform(
            transform_remove_space_time(),
            transform_normalize_with_train_statistics(MEAN, STD),
        ),
    )

    mean, std, maxi, mini = get_stats_train_dataset(
        filepath=trainpath,
        num_days=num_days,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        valid_ratio=valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
    )

    print("Mini = ")
    print(mini)

    print("Maxi =")
    print(maxi)

    it = iter(train_loader)
    X, Y = next(it)

    print("Size of X :")
    print(X.shape) # (128, 35, 18)

    ####### Test to see what is inside a sample: timeframe and localization
    #
    print("Latitudes")
    print(torch.min(X[0, :, 0]), torch.max(X[0, :, 0])) 

    print("Longitudes")
    print(torch.min(X[0, :, 1]), torch.max(X[0, :, 1]))

    print("Depths")
    print(torch.min(X[0, :, 2]), torch.max(X[0, :, 2]))

    print("Times")
    print(torch.min(X[0, :, 3]), torch.max(X[0, :, 3]))

    print("="*30)
    print("Size of Y :")
    print(Y.shape) # (128, 35)
    print("Y for first elem of batch :")
    print(Y[0,:])

    print("-"*30)
    def check_min(tensor, index):
        """
        For a tensor of shape (B, T, N), return the min value of feature `index`
        """
        mini = float("inf")
        for batch in tensor:
            for value_time in batch:
                value = value_time[index].item()
                if value < mini:
                    mini = value
        return mini

    def check_max(tensor, index):
        """
        For a tensor of shape (B, T, N), return the max value of feature `index`
        """
        maxi = float("-inf")
        for batch in tensor:
            for value_time in batch:
                value = value_time[index].item()
                if value > maxi:
                    maxi = value
        return maxi

    def check_info(tensor, index):
        print("=" * 30)
        print(f"Min value for index {index} is {check_min(tensor, index)}")
        print(f"Max value for index {index} is {check_max(tensor, index)}")
        print("=" * 30)

    print(X.shape)

    check_info(X,0) #latitude
    check_info(X,1) #longitude
    check_info(X,2) #depth
    check_info(X,3) #time

    """
    Size of X = (Batchsize, Number of days, Features (18) (lat, long, depth, time, features))
    Size of Y = (Batchsize, Number of days)
    """
    logging.info(f"Got a minibatch of size {X.shape} -> {Y.shape}")