Commit f1b4e588 authored by Yandi

[Submissions] Copying script from Kaggle

parent ec5f4bfa
1 merge request: !1 Master into main
# coding: utf-8
"""
This is a dummy test with a nonsense model, illustrating how one could
apply a model to produce predictions in the format expected by the
submission. The main algorithmic difficulties illustrated by this script
are:
- iterating over the test_dataset without shuffling iterates continuously
  over the volume (latitude, longitude, depth, time)
- when iterating over minibatches, every time series must be subsampled
  with a step of 10 days; we show how to identify where to sample the
  minibatches of predictions (a toy trace of this index arithmetic is
  given in _demo_index_selection at the bottom of this file)
"""
# Standard imports
import sys
import logging
import datetime
# External imports
import tqdm
import torch
# Local imports
import bindataset as dataset
def create_submission(model):
    step_days = 10
    batch_size = 1024
    # We make chunks of num_days consecutive samples; as our dummy predictor
    # does not use the temporal context, this value is chosen arbitrarily here.
    # However, note that it must be a divisor of the total number of days
    # in the year 2017, i.e. 1, 5, 73 or 365
    num_days = 365
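    # Optional, illustrative sanity check: fail early if num_days does not
    # evenly divide the 365 days of 2017 (365 = 5 x 73, hence the divisors
    # listed above)
    assert 365 % num_days == 0, "num_days must be one of 1, 5, 73, 365"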
    num_workers = 7
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    # Build the dataloaders
    logging.info("Building the dataloader")
    test_loader = dataset.get_test_dataloader(
        dataset._DEFAULT_TEST_FILEPATH,
        num_days,
        batch_size,
        num_workers,
        use_cuda,
        overwrite_index=True,
        transform=None,  # plug in the same input transform as used at training time
        target_transform=None,
    )
    num_days_test = test_loader.dataset.ntimes
    logging.info("= Filling in the submission file")
    with open("submission.csv", "w") as fh_submission:
        fh_submission.write("Id,Predicted\n")
        submission_offset = 0
        # Iterate on the test dataloader
        t_offset = 0
        # Every minibatch will contain batch_size * num_days samples.
        # As we do not shuffle the data, these correspond to consecutive
        # days of the same location, followed by consecutive days of the
        # next location, and so on
        chunk_size = batch_size * num_days
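        # For illustration, with hypothetical tiny values batch_size = 2 and
        # num_days = 3, a linearized minibatch would read
        # [loc0 day0, loc0 day1, loc0 day2, loc1 day0, loc1 day1, loc1 day2]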
        with torch.no_grad():
            for X in tqdm.tqdm(test_loader):
                X = X.to(device)
                #############################################
                # This is where you inject your knowledge
                # about your model.
                # The rest of the code is generic as soon as you have a
                # model working on time series:
                # X is (B, T, N)
                # predictions are (B, T)
                predictions = model(X)
                #############################################
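                # Any callable with this contract works here; e.g. the
                # nonsense averaging predictor used in __main__ below:
                # lambda X: X.mean(dim=-1)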
                # We reshape the predictions to (B * T)
                # and keep only the time instants we need
                predictions = predictions.view(-1)
                # We need to slice the times by steps of step_days days
                # in chunks of num_days_test days (2017 had 365 days)
                yearcut_indices = list(range(0, chunk_size + t_offset, num_days_test))
                # The yearcut_indices are the indices in the linearized minibatch
                # corresponding to 01/01/2017 for some (latitude, longitude, depth).
                # From these yearcut_indices, we can locate where to sample
                # the vector of predictions
                subdays_indices = [
                    y + k
                    for y in yearcut_indices
                    for k in range(0, num_days_test, step_days)
                ]
                subdays_indices = list(map(lambda i: i - t_offset, subdays_indices))
                # Remove the negative indices, if any.
                # These negative indices happen because of the offset;
                # they correspond to the locations of 01/01/2017 in the
                # previous minibatch
                subdays_indices = [
                    k
                    for k in subdays_indices
                    if 0 <= k < min(chunk_size, predictions.shape[0])
                ]
                t_offset = chunk_size - (yearcut_indices[-1] - t_offset)
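                # Toy trace of this arithmetic (hypothetical sizes
                # chunk_size = 8, num_days_test = 5, step_days = 2):
                # batch 1: t_offset = 0, yearcut = [0, 5],
                #   subdays = [0, 2, 4, 5, 7, 9] -> kept [0, 2, 4, 5, 7],
                #   next t_offset = 8 - (5 - 0) = 3
                # batch 2: yearcut = [0, 5, 10], shifted by -3 and filtered
                #   -> kept [1, 2, 4, 6, 7], next t_offset = 8 - (10 - 3) = 1
                # (see _demo_index_selection at the bottom of this file)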
                predictions_list = predictions[subdays_indices].tolist()
                # Check
                # X = X.view(-1, 18)
                # subX = X[yearcut_indices, :]
                # # subX = X
                # timestamps = subX[:, 3].tolist()
                # print(
                #     "\n".join(
                #         [f"{datetime.datetime.fromtimestamp(x)}" for x in timestamps]
                #     )
                # )
                # print("\n\n")
                # sys.exit(-1)
                # Dump the predictions to the submission file; the with
                # block closes the file once the loop is done
                submission_part = "\n".join(
                    [
                        f"{i + submission_offset},{pred}"
                        for i, pred in enumerate(predictions_list)
                    ]
                )
                fh_submission.write(submission_part + "\n")
                submission_offset += len(predictions_list)
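

def _demo_index_selection():
    """
    A minimal, self-contained sketch of the index arithmetic used in
    create_submission, run on hypothetical toy sizes (chunk_size = 8,
    num_days_test = 5, step_days = 2) instead of the real dataset, and
    assuming full minibatches; it only illustrates how the year starts
    are tracked across minibatches through t_offset.
    """
    chunk_size, num_days_test, step_days = 8, 5, 2
    t_offset = 0
    kept_per_batch = []
    for _ in range(3):
        yearcut_indices = list(range(0, chunk_size + t_offset, num_days_test))
        subdays_indices = [
            y + k - t_offset
            for y in yearcut_indices
            for k in range(0, num_days_test, step_days)
        ]
        kept_per_batch.append([k for k in subdays_indices if 0 <= k < chunk_size])
        t_offset = chunk_size - (yearcut_indices[-1] - t_offset)
    # Absolute indices 0, 5, 10, ... start a "year"; within each year we
    # keep every 2nd day, which lands on these local indices per batch:
    assert kept_per_batch == [[0, 2, 4, 5, 7], [1, 2, 4, 6, 7], [1, 3, 4, 6]]
    return kept_per_batch
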
if __name__ == "__main__":
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(message)s")
    # Run with a nonsense model that simply averages the input features,
    # mapping (B, T, N) inputs to (B, T) predictions; replace this with a
    # real trained model
    create_submission(lambda X: X.mean(dim=-1))