# dataloader.py
# Standard imports
import os
import sys
import logging
import pathlib
from typing import Union
import struct

from datetime import datetime

# import multiprocessing
from multiprocessing import Lock

# from multiprocessing.synchronize import Lock

# External imports
import numpy as np
import tqdm
import torch
import torch.utils.data as data

# Imports from the project
from dataset import Dataset

_DEFAULT_TRAIN_FILEPATH = "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin"
_DEFAULT_TEST_FILEPATH = (
    "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2017-testing.nc.bin"
)

_ENCODING_LINEAR = "I"
_ENCODING_INDEX = "I"  # "h" (2-byte short) would also be sufficient
_ENCODING_OFFSET_FORMAT = ""
_ENCODING_ENDIAN = "<"

def train_valid_split(
    dataset,
    valid_ratio,
    train_subset_filepath,
    valid_subset_filepath,
    max_num_samples=None,
):
    """
    Generates two splits of (1 - valid_ratio) and valid_ratio fractions of the dataset
    Takes about 2 minutes to generate the splits for 50M datapoints

    For small sets, we could instead keep the indices in main memory and use
    torch.utils.data.Subset
    """

    N = len(dataset)
    if max_num_samples is not None:
        pkeep = max_num_samples / N
        Nvalid = int(valid_ratio * max_num_samples)
        Ntrain = max_num_samples - Nvalid
        ptrain = Ntrain / max_num_samples

    else:
        pkeep = 1.0
        Nvalid = int(valid_ratio * N)
        Ntrain = N - Nvalid
        ptrain = Ntrain / N

    ftrain = open(train_subset_filepath, "wb")
    fvalid = open(valid_subset_filepath, "wb")

    gen = np.random.default_rng()

    logging.info(f"Generating the subset files from {N} samples")
    for i in tqdm.tqdm(range(N)):
        if gen.uniform() < pkeep:
            if gen.uniform() < ptrain:
                ftrain.write(struct.pack(_ENCODING_ENDIAN + _ENCODING_LINEAR, i))
            else:
                fvalid.write(struct.pack(_ENCODING_ENDIAN + _ENCODING_LINEAR, i))
    fvalid.close()
    ftrain.close()
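
# Illustrative sketch (an assumption, not part of the original file): the
# subset files written by train_valid_split are flat sequences of 4-byte
# little-endian unsigned ints, so they can be read back for inspection with a
# hypothetical helper like the one below.
def _read_subset_indices(subset_filepath):
    """Hypothetical helper: load the sample indices stored in a subset file."""
    item_size = struct.calcsize(_ENCODING_ENDIAN + _ENCODING_LINEAR)
    indices = []
    with open(subset_filepath, "rb") as fh:
        chunk = fh.read(item_size)
        while chunk:
            (idx,) = struct.unpack(_ENCODING_ENDIAN + _ENCODING_LINEAR, chunk)
            indices.append(idx)
            chunk = fh.read(item_size)
    return indices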


def get_dataloaders(
    filepath,
    num_days,
    batch_size,
    num_workers,
    pin_memory,
    valid_ratio,
    overwrite_index=True,
    train_transform=None,
    train_target_transform=None,
    valid_transform=None,
    valid_target_transform=None,
    max_num_samples=None,
):
    logging.info("= Dataloaders")
    # Load the base dataset
    logging.info("  - Dataset creation")
    dataset = Dataset(
        filepath, train=True, overwrite_index=overwrite_index, num_days=num_days
    )
    logging.info(f"  - Loaded a dataset with {len(dataset)} samples")

    # Split the number of samples in each fold
    logging.info("  - Splitting the data in training and validation sets")
    train_subset_file = "train_indices.subset"
    valid_subset_file = "valid_indices.subset"
    train_valid_split(
        dataset, valid_ratio, train_subset_file, valid_subset_file, max_num_samples
    )

    logging.info("  - Subset dataset")
    train_dataset = Dataset(
        filepath,
        subset_file=train_subset_file,
        transform=train_transform,
        target_transform=train_target_transform,
        num_days=num_days,
    )
    valid_dataset = Dataset(
        filepath,
        subset_file=valid_subset_file,
        transform=valid_transform,
        target_transform=valid_target_transform,
        num_days=num_days,
    )

    # The sizes of the two folds are not expected to sum exactly to
    # max_num_samples
    logging.info(f"  - The train fold has {len(train_dataset)} samples")
    logging.info(f"  - The valid fold has {len(valid_dataset)} samples")

    # Build the dataloaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, valid_loader


def get_test_dataloader(
    filepath,
    num_days,
    batch_size,
    num_workers,
    pin_memory,
    overwrite_index=True,
    transform=None,
    target_transform=None,
):
    logging.info("= Dataloaders")
    # Load the base dataset
    logging.info("  - Dataset creation")
    test_dataset = Dataset(
        filepath,
        train=False,
        transform=transform,
        target_transform=target_transform,
        num_days=num_days,
        overwrite_index=overwrite_index,
    )
    logging.info(f"  - Loaded {len(test_dataset)} samples in the test set")

    # Build the dataloader; be careful not to shuffle the data so that the
    # samples keep their original order
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return test_loader

def transform_remove_space_time():
    """Returns a transform dropping the first four attributes
    (latitude, longitude, depth, time), keeping only the remaining features"""
    return lambda attributes: attributes[:, 4:]
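
# Illustrative sketch (an assumption, not in the original file): the transform
# above is meant to be passed as train_transform / valid_transform to
# get_dataloaders; on a dummy (num_days, 18) input it strips the first four
# columns (latitude, longitude, depth, time).
def _demo_transform_remove_space_time():
    """Hypothetical demo of the transform on a dummy attribute tensor."""
    transform = transform_remove_space_time()
    dummy = torch.zeros(35, 18)  # (num_days, num_attributes)
    reduced = transform(dummy)   # lat/long/depth/time columns removed
    assert reduced.shape == (35, 14)
    return reduced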

if __name__ == "__main__":
    logging.basicConfig(filename='logs/dataloader_unit_test.log', level=logging.INFO)
    logging.info("====> Test dataloader")
    use_cuda = torch.cuda.is_available()
    trainpath = _DEFAULT_TRAIN_FILEPATH
    num_days = 35  # Length (in days) of the input sequences
    batch_size = 128
    num_workers = 7
    valid_ratio = 0.2
    # max_num_samples = 1000
    max_num_samples = None

    train_loader, valid_loader = get_dataloaders(
        filepath = trainpath,
        num_days = num_days,
        batch_size = batch_size,
        num_workers = num_workers,
        pin_memory = True,
        valid_ratio = valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
    )

    it = iter(train_loader)
    X, Y = next(it)

    def check_min(tensor, index):
        """
        For a tensor of shape (B, T, N), return the minimum value of
        feature `index` over all batch items and time steps
        """
        mini = float("inf")
        for batch in tensor:
            for value_time in batch:
                value = value_time[index].item()
                if value < mini:
                    mini = value
        print(f"Min value for index {index} is {mini}")
        return mini

    def check_max(tensor, index):
        """
        For a tensor of shape (B, T, N), return the maximum value of
        feature `index` over all batch items and time steps
        """
        maxi = float("-inf")
        for batch in tensor:
            for value_time in batch:
                value = value_time[index].item()
                if value > maxi:
                    maxi = value
        print(f"Max value for index {index} is {maxi}")
        return maxi

    def check_info(tensor, index):
        print("="*30)
        print(check_min(tensor, index))
        print(check_max(tensor, index))
        print("="*30)

    print(X.shape)

    check_info(X,0) #latitude
    check_info(X,1) #longitude
    check_info(X,2) #depth
    check_info(X,3) #time

    """
    Size of X: (batch_size, num_days, 18), the 18 attributes being
    latitude, longitude, depth, time followed by the remaining features
    Size of Y: (batch_size, num_days)
    """
    logging.info(f"Got a minibatch of size {X.shape} -> {Y.shape}")