Skip to content
Snippets Groups Projects
dataset.py 13.10 KiB
# Standard imports
import os
import sys
import logging
import pathlib
from typing import Union
import struct

from datetime import datetime

# import multiprocessing
from multiprocessing import Lock

# from multiprocessing.synchronize import Lock

# External imports
import numpy as np
import tqdm
import torch
import torch.utils.data as data

# Default locations of the challenge data files
# (training covers 2010-2016, testing covers 2017, per the file names)
_DEFAULT_TRAIN_FILEPATH = (
    "/mounts/Datasets3/2022-ChallengePlankton/"
    "sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin"
)
_DEFAULT_TEST_FILEPATH = (
    "/mounts/Datasets3/2022-ChallengePlankton/"
    "sub_2CMEMS-MEDSEA-2017-testing.nc.bin"
)

# struct format codes used by the on-disk index records.
# "I" is a 4-byte unsigned int in standard ("<") mode, so indexed byte
# offsets must stay below 4 GiB or struct.pack will raise.
_ENCODING_LINEAR = "I"  # linear sample index and data-file byte offset
# Per-dimension indices: a 2-byte code ("H") would be enough, but switching
# would change the format of already generated index files.
_ENCODING_INDEX = "I"
_ENCODING_OFFSET_FORMAT = ""  # NOTE(review): not referenced in the visible code
_ENCODING_ENDIAN = "<"  # little-endian, standard sizes, no alignment padding


class Dataset(data.Dataset):
    """
    Dataset for training a predictor predicting the phytoplankton
    density from (a series of) environmental variables
    + time + depth + latitude + longitude

    Each sample covers ``num_days`` consecutive rows of one valid
    (latitude, longitude, depth) location. Its input tensor has shape
    (num_days, 4 + len(in_variables)): latitude, longitude, depth and time,
    followed by the environmental variables. In train mode the target
    tensor has shape (num_days,) and holds the phyc values.
    (Shapes are before any user-supplied transform.)

    According to the original notes,
    sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin contains 50 154 984 valid
    data points.

    An index mapping linear sample indices to byte offsets in the data file
    is generated on first use (one seek + read per location, so it is slow
    on the full files) and cached in the current working directory.
    """

    def __init__(
        self,
        filepath: Union[pathlib.Path, str],
        overwrite_index: bool = False,
        train=True,
        subset_file=None,
        num_days=1,
        transform=None,
        target_transform=None,
    ):
        """
        Builds a point dataset, generating the index if necessary or requested

        Arguments:
            filepath: the full path to the binary data file to load
            overwrite_index: if True ignores the existing index and regenerates it
            train: if True, rows are expected to end with a phyc column and
                accessing an element also returns the phyc
            subset_file: a filename which holds a list of indices this dataset must use
            num_days: the number of consecutive days that each sample covers (>= 1)
            transform: a transform to apply to the input tensor
            target_transform: a transform to apply to the phyc output tensor

        Raises:
            ValueError: if num_days < 1
        """
        super().__init__()
        # Fail early: num_days is used as a range step and a reshape size
        if num_days < 1:
            raise ValueError(f"num_days must be >= 1, got {num_days}")

        if isinstance(filepath, str):
            filepath = pathlib.Path(filepath)
        self.filepath = filepath

        # Load the dataset. It can be indexed by [lat, long, depth, time]
        self.fp = open(filepath, "rb")
        self.fp_lock = Lock()

        # List all the environmental variables
        # To which we could add time, depth, latitude, longitude
        self.in_variables = [
            "dissic",
            "mlotst",
            "nh4",
            "no3",
            "nppv",
            "o2",
            "ph",
            "po4",
            "so",
            "talk",
            "thetao",
            "uo",
            "vo",
            "zos",
        ]
        self.train = train
        if self.train:
            self.out_variable = "phyc"

        # Store the size of the timeseries, in days
        self.num_days = num_days

        # A row is: valid flag, one float per input variable,
        # plus a trailing phyc float in train mode
        self.row_format = (
            "?" + "f" * len(self.in_variables) + ("f" if self.train else "")
        )
        self.row_size = struct.calcsize(_ENCODING_ENDIAN + self.row_format)

        # Load the header of the file for the dimension variables
        # Local utility function to parse one header entry:
        # an int32 length followed by that many values of base_format
        def _read_dim(fp, offset, base_format, lock):
            fmt = _ENCODING_ENDIAN + "i"
            (dim, nbytes_dim) = read_bin_data(fp, offset, os.SEEK_SET, fmt, lock)

            dim = dim[0]  # the returned value is a tuple
            offset += nbytes_dim

            fmt = _ENCODING_ENDIAN + (base_format * dim)
            (values, nbytes_values) = read_bin_data(fp, offset, os.SEEK_SET, fmt, lock)
            return dim, np.array(values), nbytes_dim + nbytes_values

        self.header_offset = 0
        self.nlatitudes, self.latitudes, nbytes = _read_dim(
            self.fp, self.header_offset, "f", self.fp_lock
        )
        self.header_offset += nbytes

        self.nlongitudes, self.longitudes, nbytes = _read_dim(
            self.fp, self.header_offset, "f", self.fp_lock
        )
        self.header_offset += nbytes

        self.ndepths, self.depths, nbytes = _read_dim(
            self.fp, self.header_offset, "f", self.fp_lock
        )
        self.header_offset += nbytes

        # self.times holds the raw int32 time values of the header
        # (presumably POSIX timestamps; use datetime.fromtimestamp if datetimes are needed)
        self.ntimes, self.times, nbytes = _read_dim(
            self.fp, self.header_offset, "i", self.fp_lock
        )
        self.header_offset += nbytes

        # Number of rows spanned by one step along each dimension;
        # time is the fastest-varying axis of the data block
        self.lat_chunk_size = self.nlongitudes * self.ndepths * self.ntimes
        self.lon_chunk_size = self.ndepths * self.ntimes
        self.depth_chunk_size = self.ntimes

        logging.info(
            f"The loaded dataset contains {self.nlatitudes} latitudes, {self.nlongitudes} longitudes, {self.ndepths} depths and {self.ntimes} time points"
        )

        # We store an index which maps a linear index of valid measures
        # (i.e. excluding points where there is no phytoplankton record)
        # to the byte offset in the data file
        self._index_file = None
        self._index_lock = Lock()
        self._load_index(filepath.name, overwrite_index)

        # The subset map is a filehandler to a file
        # where the i-th row contains the i-th index in the original dataset
        self._subset_map = None
        if subset_file is not None:
            self._subset_lock = Lock()
            self._load_subset_index(subset_file)

        self.transform = transform
        self.target_transform = target_transform

    def _generate_index(self, indexpath):
        """
        Generate an index for the binary data file.

        Each index record maps a linear sample index to the byte offset of
        the sample's first row in the data file, plus its latitude,
        longitude, depth and time indices. Only valid locations are indexed
        (coast, etc. are skipped), and each valid location contributes one
        sample per non-overlapping window of ``num_days`` time steps.

        The records are first written to a temporary file, which is renamed
        to ``indexpath`` only once complete. An interrupted run therefore
        never leaves a truncated index that would later be reused silently.

        Arguments:
            indexpath: pathlib.Path of the index file to create
        """
        record_fmt = (
            _ENCODING_ENDIAN + (_ENCODING_LINEAR * 2) + (_ENCODING_INDEX * 4)
        )
        valid_fmt = _ENCODING_ENDIAN + "?"
        valid_size = struct.calcsize(valid_fmt)
        tmp_indexpath = indexpath.with_name(indexpath.name + ".tmp")

        linear_index = 0
        with open(self.filepath, "rb") as fp, open(tmp_indexpath, "wb") as fhindex:
            for ilatitude in tqdm.tqdm(range(self.nlatitudes)):
                for ilongitude in range(self.nlongitudes):
                    for idepth in range(self.ndepths):
                        # No need to iterate over time: latitude x longitude
                        # x depth alone specifies a valid or invalid location,
                        # so only the first row's flag is read
                        t0_offset = (
                            self.header_offset
                            + (
                                ilatitude * self.lat_chunk_size
                                + ilongitude * self.lon_chunk_size
                                + idepth * self.depth_chunk_size
                            )
                            * self.row_size
                        )
                        fp.seek(t0_offset, os.SEEK_SET)
                        (is_valid,) = struct.unpack(valid_fmt, fp.read(valid_size))
                        if not is_valid:
                            continue

                        # Record one sample per full window of num_days rows
                        for dt in range(
                            0, self.ntimes - self.num_days + 1, self.num_days
                        ):
                            fileoffset = t0_offset + dt * self.row_size
                            write_bin_data(
                                fhindex,
                                record_fmt,
                                (
                                    linear_index,
                                    fileoffset,
                                    ilatitude,
                                    ilongitude,
                                    idepth,
                                    dt,
                                ),
                            )
                            linear_index += 1

        # Publish the complete index in one step
        tmp_indexpath.replace(indexpath)

    def _load_index(self, basename: str, overwrite_index: bool):
        """
        Loads (and possibly computes) the index file used to convert a
        linear index idx to its file offset and its time, depth, latitude
        and longitude indices.

        Arguments:
            basename: name of the data file, used to derive the index file name
            overwrite_index: if True, regenerate the index even if it exists
        """
        # The index content depends on num_days (window length and stride),
        # so every num_days value needs its own file. num_days == 1 keeps
        # the historical name so existing index files stay valid.
        if self.num_days == 1:
            indexname = f"{basename}_index.idx"
        else:
            indexname = f"{basename}_{self.num_days}days_index.idx"
        indexpath = pathlib.Path(".") / indexname

        # Generate the index if necessary or requested
        if not indexpath.exists() or overwrite_index:
            logging.info("Generating the index")
            self._generate_index(indexpath)
        # And then load the index
        logging.info(f"Loading the index from {indexpath}")
        self._index_file = open(indexpath, "rb")

    def _load_subset_index(self, subsetpath):
        """Open the subset file mapping subset positions to dataset indices."""
        self._subset_map = open(subsetpath, "rb")

    def _get_fileoffset(self, idx):
        """
        Read the idx-th record of the index file.

        Arguments:
            idx: record number; negative values count from the end of the index

        Returns:
            (linear index, data-file byte offset,
             (ilatitude, ilongitude, idepth, itime))
        """
        fmt = _ENCODING_ENDIAN + (_ENCODING_LINEAR * 2) + (_ENCODING_INDEX * 4)
        whence = os.SEEK_SET if idx >= 0 else os.SEEK_END
        offset = idx * struct.calcsize(fmt)

        (values, _) = read_bin_data(
            self._index_file, offset, whence, fmt, self._index_lock
        )
        linidx = values[0]
        file_offset = values[1]
        tab_indices = values[2:]
        return linidx, file_offset, tab_indices

    def __getitem__(self, idx):
        """
        Access the idx-th sample.

        Arguments:
            idx: sample index (a position in the subset when a subset file is used)

        Returns:
            In train mode, a pair (inputs, target). inputs has shape
            (num_days, 4 + len(in_variables)): latitude, longitude, depth and
            time, then the environmental variables. target has shape
            (num_days,) and holds phyc. Otherwise only inputs is returned.
            Transforms, when set, are applied to the returned tensors.
        """
        # If we are processing a subset, convert the "idx" to the
        # original dataset index
        if self._subset_map is not None:
            fmt = _ENCODING_ENDIAN + _ENCODING_LINEAR
            offset = idx * struct.calcsize(fmt)
            (values, _) = read_bin_data(
                self._subset_map, offset, os.SEEK_SET, fmt, self._subset_lock
            )
            idx = values[0]

        # File offset for the i-th sample
        _, file_offset, (ilatitude, ilongitude, idepth, itime) = self._get_fileoffset(
            idx
        )
        values, _ = read_bin_data(
            self.fp,
            file_offset,
            os.SEEK_SET,
            _ENCODING_ENDIAN + (self.row_format * self.num_days),
            self.fp_lock,
        )

        # One row per day: [valid flag, variables...]
        values = np.array(values).reshape((self.num_days, -1))
        # Always expected to be True since only valid locations
        # are listed in our generated index
        assert np.all(values[:, 0])

        # Prepend the latitude, longitude, depth (constant over the window)
        # and the time of each day
        # TODO Take the time to think about whether this is relevant !
        coordinates = np.empty((self.num_days, 4))
        coordinates[:, 0] = self.latitudes[ilatitude]
        coordinates[:, 1] = self.longitudes[ilongitude]
        coordinates[:, 2] = self.depths[idepth]
        coordinates[:, 3] = self.times[itime : (itime + self.num_days)]

        # Stick the values (without the valid flag) to the coordinates
        # This np array is now (num_days, num_coordinates + num_values)
        values = np.hstack([coordinates, values[:, 1:]])

        if self.train:
            # The last column is phyc, the target
            in_values = torch.Tensor(values[:, :-1])
            out_values = torch.Tensor(values[:, -1])

            if self.transform is not None:
                in_values = self.transform(in_values)

            if self.target_transform is not None:
                out_values = self.target_transform(out_values)

            return in_values, out_values

        # Without train, rows carry no phyc column: every column is an input.
        # (Slicing off the last column here would drop the last variable.)
        in_values = torch.Tensor(values)
        if self.transform is not None:
            in_values = self.transform(in_values)
        return in_values

    def __len__(self):
        """
        Returns the total number of valid samples for this dataset

        This corresponds to the length of the index or subset index
        """
        # If we are processing a subset, the size of the dataset
        # is the size of the subset
        if self._subset_map is not None:
            subset_size = os.path.getsize(self._subset_map.name)
            return subset_size // struct.calcsize(_ENCODING_ENDIAN + _ENCODING_LINEAR)

        # An empty index (e.g. num_days longer than the time series) has no
        # last record; seeking before the start of the file would fail
        if os.path.getsize(self._index_file.name) == 0:
            return 0

        # Access the last record of the index file and get its linear
        # index. Linear indices are written sequentially from 0, so this
        # is the total number of records minus 1
        lin_idx, _, _ = self._get_fileoffset(-1)
        return lin_idx + 1



def write_bin_data(fp, fmt, values):
    """
    Pack ``values`` according to the struct format ``fmt`` and write the
    resulting bytes to ``fp`` at its current position.

    Arguments:
        fp: a binary file-like object opened for writing
        fmt: struct format string describing the record
        values: sequence of values matching ``fmt``
    """
    payload = struct.pack(fmt, *values)
    fp.write(payload)


def read_bin_data(fp, offset, whence, fmt, lock):
    """
    Read and unpack one struct-formatted record from the file behind ``fp``.

    The file is reopened by name on every call instead of reusing ``fp``.
    Per the original author's note, sharing the handle under ``lock`` did
    not work, so ``lock`` is accepted but unused and each call gets its own
    private handle (at the cost of an open/close per read).

    Arguments:
        fp: an open file object; only its ``name`` attribute is used
        offset: byte offset passed to ``seek``
        whence: reference point for ``offset`` (os.SEEK_SET, os.SEEK_END, ...)
        fmt: struct format string of the record to read
        lock: unused, kept for signature compatibility

    Returns:
        (values, nbytes): the unpacked tuple and the record size in bytes
    """
    with open(fp.name, "rb") as private_fh:
        private_fh.seek(offset, whence)
        record_size = struct.calcsize(fmt)
        raw = private_fh.read(record_size)
    return struct.unpack(fmt, raw), record_size