# Standard imports
import os
import sys
import logging
import pathlib
from typing import Union
import struct
from datetime import datetime
# import multiprocessing
from multiprocessing import Lock
# from multiprocessing.synchronize import Lock
# External imports
import numpy as np
import tqdm
import torch
import torch.utils.data as data
_DEFAULT_TRAIN_FILEPATH = "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin"
_DEFAULT_TEST_FILEPATH = (
"/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2017-testing.nc.bin"
)
_ENCODING_LINEAR = "I"  # unsigned int, used for the linear index and the file offset
_ENCODING_INDEX = "I"  # 'h' (short, 2 bytes) would be sufficient for the 4D indices
_ENCODING_OFFSET_FORMAT = ""
_ENCODING_ENDIAN = "<"  # little-endian, standard sizes, no padding
class Dataset(data.Dataset):
"""
    Dataset for training a predictor of the phytoplankton
    density from (a series of) environmental variables
    + time + depth + latitude + longitude

    The file sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin
    contains 50 154 984 valid data points
    Generating the index takes approximately 20 seconds
"""
def __init__(
self,
filepath: Union[pathlib.Path, str],
overwrite_index: bool = False,
train=True,
subset_file=None,
num_days=1,
transform=None,
target_transform=None,
):
"""
        Builds a point dataset, generating the index if necessary or requested
Arguments:
filepath: the full path to the nc file to load
overwrite_index: if True ignores the index and regenerates it
train: if True, accessing an element also returns the phyc
subset_file: a filename which holds a list of indices this dataset must use
num_days: the number of days that each sample considers
transform: a transform to apply to the input tensor
target_transform: a transform to apply to the phyc output tensor
"""
super().__init__()
if isinstance(filepath, str):
filepath = pathlib.Path(filepath)
self.filepath = filepath
# Load the dataset. It can be indexed by [lat, long, depth, time]
self.fp = open(filepath, "rb")
# self.fp.close()
self.fp_lock = Lock()
# List all the environmental variables
# To which we could add time, depth, latitude, longitude
self.in_variables = [
"dissic",
"mlotst",
"nh4",
"no3",
"nppv",
"o2",
"ph",
"po4",
"so",
"talk",
"thetao",
"uo",
"vo",
"zos",
# "phyc",
]
self.train = train
if self.train:
self.out_variable = "phyc"
# Store the size of the timeseries, in days
self.num_days = num_days
# Set up the format of a row
self.row_format = (
"?" + "f" * len(self.in_variables) + ("f" if self.train else "")
)
self.row_size = struct.calcsize(_ENCODING_ENDIAN + self.row_format)
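        # With the "<" encoding there is no padding: a row on disk is one
        # validity byte followed by 14 float32 environmental variables,
        # plus one float32 phyc value in the training case,
        # i.e. 1 + 14 * 4 (+ 4) = 61 bytes when training, 57 bytes otherwise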
        # Load the header of the file for the dimension variables
        # Local utility function to parse one dimension of the header
def _read_dim(fp, offset, base_format, lock):
fmt = _ENCODING_ENDIAN + "i"
(dim, nbytes_dim) = read_bin_data(fp, offset, os.SEEK_SET, fmt, lock)
            dim = dim[0]  # the returned value is a tuple
offset += nbytes_dim
fmt = _ENCODING_ENDIAN + (base_format * dim)
(values, nbytes_values) = read_bin_data(fp, offset, os.SEEK_SET, fmt, lock)
return dim, np.array(values), nbytes_dim + nbytes_values
self.header_offset = 0
self.nlatitudes, self.latitudes, nbytes = _read_dim(
self.fp, self.header_offset, "f", self.fp_lock
)
self.header_offset += nbytes
self.nlongitudes, self.longitudes, nbytes = _read_dim(
self.fp, self.header_offset, "f", self.fp_lock
)
self.header_offset += nbytes
self.ndepths, self.depths, nbytes = _read_dim(
self.fp, self.header_offset, "f", self.fp_lock
)
self.header_offset += nbytes
self.ntimes, self.times, nbytes = _read_dim(
self.fp, self.header_offset, "i", self.fp_lock
)
# If we need to convert timestamps to datetimes
# self.dtimes = np.array([datetime.fromtimestamp(di) for di in dtimes.tolist()])
self.header_offset += nbytes
self.lat_chunk_size = self.nlongitudes * self.ndepths * self.ntimes
self.lon_chunk_size = self.ndepths * self.ntimes
self.depth_chunk_size = self.ntimes
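        # A sample at indices (ilat, ilon, idepth, itime) therefore starts at byte
        # header_offset + (ilat * lat_chunk_size + ilon * lon_chunk_size
        #                  + idepth * depth_chunk_size + itime) * row_size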
logging.info(
f"The loaded dataset contains {self.nlatitudes} latitudes, {self.nlongitudes} longitudes, {self.ndepths} depths and {self.ntimes} time points"
)
# We store an index which maps a linear index of valid measures
# (i.e. excluding points where there is no phytoplankton record)
# to the byte offset in the data file
self._index_file = None
self._index_lock = Lock()
self._load_index(filepath.name, overwrite_index)
# The subset map is a filehandler to a file
# where the i-th row contains the i-th index in the original dataset
self._subset_map = None
if subset_file is not None:
self._subset_lock = Lock()
self._load_subset_index(subset_file)
self.transform = transform
self.target_transform = target_transform
def _generate_index(self, indexpath):
"""
        Generate an index for the binary data file
        which maps the linear sample index to the byte offset
        and 4D indices of the valid measures in the original data file
        Note that we only index valid measures, i.e. we exclude
        locations without records (coast, land, ...)
        Generating the index takes around 2 minutes for 50M samples
"""
fp = open(self.filepath, "rb")
with open(indexpath, "wb") as fhindex:
linear_index = 0
# Rewind the file handler to just after the header
fp.seek(self.header_offset, os.SEEK_SET)
for ilatitude in tqdm.tqdm(range(self.nlatitudes)):
for ilongitude in range(self.nlongitudes):
for idepth in range(self.ndepths):
# No need to iterate over time
# only latitude x longitude x depth
# specify a valid or invalid location
t0_offset = (
self.header_offset
+ (
ilatitude * self.lat_chunk_size
+ ilongitude * self.lon_chunk_size
+ idepth * self.depth_chunk_size
)
* self.row_size
)
fp.seek(t0_offset, os.SEEK_SET)
# Just read the 'valid' field
(is_valid,) = struct.unpack(
"<?", fp.read(struct.calcsize("<?"))
)
# If the location is valid
# we record all the time samples
if is_valid:
for dt in range(
0, self.ntimes - self.num_days + 1, self.num_days
):
fileoffset = t0_offset + dt * self.row_size
fmt = (
_ENCODING_ENDIAN
+ (_ENCODING_LINEAR * 2)
+ (_ENCODING_INDEX * 4)
)
write_bin_data(
fhindex,
fmt,
(
linear_index,
fileoffset,
ilatitude,
ilongitude,
idepth,
dt,
),
)
linear_index += 1
fp.close()
def _load_index(self, basename: str, overwrite_index: bool):
"""
        Loads (and possibly computes) an index file
        used to convert a linear index idx into its corresponding
        time, depth, latitude and longitude indices
        Arguments:
            basename: the name of the data file, used to derive the index filename
            overwrite_index: if True, the index is regenerated even if it already exists
"""
indexpath = pathlib.Path(".") / f"{basename}_index.idx"
# Generate the index if necessary or requested
if not indexpath.exists() or overwrite_index:
logging.info("Generating the index")
self._generate_index(indexpath)
# And then load the index
logging.info(f"Loading the index from {indexpath}")
self._index_file = open(indexpath, "rb")
    def _load_subset_index(self, subsetpath):
        """
        Opens the subset file; it is expected to contain consecutive
        little-endian unsigned ints, the i-th one being the index, in the
        full dataset, of the i-th sample of the subset
        """
        self._subset_map = open(subsetpath, "rb")
def _get_fileoffset(self, idx):
fmt = _ENCODING_ENDIAN + (_ENCODING_LINEAR * 2) + (_ENCODING_INDEX * 4)
        # A negative idx is interpreted from the end of the index file
        # (e.g. idx = -1 gives the last record, used by __len__)
        whence = os.SEEK_SET if idx >= 0 else os.SEEK_END
offset = idx * struct.calcsize(fmt)
(values, _) = read_bin_data(
self._index_file, offset, whence, fmt, self._index_lock
)
linidx = values[0]
file_offset = values[1]
tab_indices = values[2:]
return linidx, file_offset, tab_indices
def __getitem__(self, idx):
"""
        Access the idx-th measure
        Returns a (num_days, 4 + 14) input tensor holding the latitude, longitude,
        depth and time followed by the 14 environmental variables and, if
        self.train, the (num_days,) phyc values to predict
"""
# If we are processing a subet, convert the "idx" to the
# original dataset index
if self._subset_map is not None:
fmt = _ENCODING_ENDIAN + _ENCODING_LINEAR
offset = idx * struct.calcsize(fmt)
(values, _) = read_bin_data(
self._subset_map, offset, os.SEEK_SET, fmt, self._subset_lock
)
idx = values[0]
# File offset for the i-th sample
_, file_offset, (ilatitude, ilongitude, idepth, itime) = self._get_fileoffset(
idx
)
values, _ = read_bin_data(
self.fp,
file_offset,
os.SEEK_SET,
_ENCODING_ENDIAN + (self.row_format * self.num_days),
self.fp_lock,
)
values = np.array(values).reshape((self.num_days, -1))
assert np.all(
values[:, 0]
) # This is always expected to be True since only valid
# locations can be indexed with our generated index map
        # Prepend the latitude, longitude, depth and time
        # TODO Take the time to think about whether this is relevant!
coordinates = np.ones((self.num_days, 4))
# print(coordinates.shape, self.times[itime : (itime + self.num_days)])
coordinates[:, 0] *= self.latitudes[ilatitude]
coordinates[:, 1] *= self.longitudes[ilongitude]
coordinates[:, 2] *= self.depths[idepth]
coordinates[:, 3] = self.times[itime : (itime + self.num_days)]
# Stick the values with the coordinates
# This np array is now (num_days , num_coordinates + num_values)
values = np.hstack([coordinates, values[:, 1:]])
if self.train:
in_values = torch.Tensor(values[:, :-1])
out_values = torch.Tensor(values[:, -1])
if self.transform is not None:
in_values = self.transform(in_values)
if self.target_transform is not None:
out_values = self.target_transform(out_values)
return in_values, out_values
        else:
            # Without the phyc column, every remaining column is an input
            in_values = torch.Tensor(values)
if self.transform is not None:
in_values = self.transform(in_values)
return in_values
def __len__(self):
"""
Returns the total number of valid datapoints for this dataset
This corresponds to the length of the index or subset index
"""
# If we are processing a subset, the size of the dataset
# is the size of the subset
if self._subset_map is not None:
subset_size = os.path.getsize(self._subset_map.name)
return subset_size // struct.calcsize(_ENCODING_ENDIAN + _ENCODING_LINEAR)
else:
            # Access the last record of the index file
            # and get its linear index. This linear index is also
            # the total number of "rows" in the file, minus 1
lin_idx, _, _ = self._get_fileoffset(-1)
return lin_idx + 1
def write_bin_data(fp, fmt, values):
    """Packs the values according to fmt and writes them to the file handler fp"""
    fp.write(struct.pack(fmt, *values))
def read_bin_data(fp, offset, whence, fmt, lock):
    """
    Reads and unpacks struct.calcsize(fmt) bytes at the given offset
    Returns the unpacked values and the number of bytes read
    """
    # Unfortunately, reading through the shared file handler under the lock
    # does not work; if it did, it would avoid opening/closing the file repeatedly
    # with lock:
    with open(fp.name, "rb") as local_fp:
        # values = struct.unpack_from(fmt, local_fp, offset)
        local_fp.seek(offset, whence)
        nbytes = struct.calcsize(fmt)
        values = struct.unpack(fmt, local_fp.read(nbytes))
        return values, nbytes