diff --git a/config.yml b/config.yml
index dce0a3ccb616b311177376a3e720d802b43ddc2f..ca6f7415598752e48584f85351cba6bff07ef2b2 100644
--- a/config.yml
+++ b/config.yml
@@ -11,11 +11,15 @@ Dataset:
   _ENCODING_INDEX: "I" # h(short) with 2 bytes should be sufficient
   _ENCODING_OFFSET_FORMAT: ""
   _ENCODING_ENDIAN: "<"
-  Transform: dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_normalize_with_train_statistics(MEAN, STD))
+  Transform: dataloader.composite_transform(dataloader.transform_remove_time(), dataloader.transform_normalize_with_train_statistics(MEAN, STD))
   # Available transforms:
   # dataloader.transform_remove_space_time()
   # dataloader.transform_normalize_with_train_statistics(MEAN, STD)
   # dataloader.transform_min_max_scaling(MIN, MAX)
+  # dataloader.transform_remove_time()
+  # To compose multiple transforms, you can use the function :
+  # dataloader.composite_transform(..., ...)
+  # Example :
   # dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX))
 
 # Data Transformation
diff --git a/dataloader.py b/dataloader.py
index d312555015cb363fa1f7b0f04e41d6652cc615d7..54a3e52290f6dbdc52d8d3212a58d5cf706f166b 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -22,16 +22,6 @@ import torch.utils.data as data
 # Imports from the project
 from dataset import Dataset
 
-_DEFAULT_TRAIN_FILEPATH = "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin"
-_DEFAULT_TEST_FILEPATH = (
-    "/mounts/Datasets3/2022-ChallengePlankton/sub_2CMEMS-MEDSEA-2017-testing.nc.bin"
-)
-
-_ENCODING_LINEAR = "I"
-_ENCODING_INDEX = "I" # h(short) with 2 bytes should be sufficient
-_ENCODING_OFFSET_FORMAT = ""
-_ENCODING_ENDIAN = "<"
-
 def train_valid_split(
     dataset,
     valid_ratio,
@@ -263,6 +253,9 @@ def get_test_dataloader(
 def transform_remove_space_time():
     return lambda attributes : attributes[:,4:]
 
+def transform_remove_time():
+    return lambda attributes : torch.cat((attributes[:, :3], attributes[:, 4:]), dim=1)
+
 def transform_normalize_with_train_statistics(MEAN, STD):
     return lambda attributes : torch.div(torch.sub(attributes, MEAN), STD)
 
diff --git a/main.py b/main.py
index f9ca5f6baed6992bde691d512f6ebb3fa004f05a..fe6efdf5863f5b9a0f8c61cf73c33143cb945765 100644
--- a/main.py
+++ b/main.py
@@ -38,6 +38,7 @@ def train(args, cfg):
     log_interval = int(cfg["Wandb"]["log_interval"])
 
     dataset_transform = cfg["Dataset"]["Transform"]
+    input_size = 14 if "space_time" in dataset_transform else (17 if "time" in dataset_transform else 18)
 
     if not args.no_wandb:
         wandb.init(entity = "wherephytoplankton", project = "Kaggle phytoplancton", config = {"batch_size": batch_size, "epochs": epochs})
@@ -82,7 +83,7 @@ def train(args, cfg):
     else :
         device = toch.device('cpu')
 
-    network = model.build_model(cfg, 14)
+    network = model.build_model(cfg, input_size)
 
     network = network.to(device)
 
@@ -148,7 +149,10 @@ def test(args):
 
     model_path = args.PATHTOCHECKPOINT
 
-    network = model.build_model(cfg, 14)
+    dataset_transform = cfg["Dataset"]["Transform"]
+    input_size = 14 if "space_time" in dataset_transform else (17 if "time" in dataset_transform else 18)
+
+    network = model.build_model(cfg, input_size)
 
     network = model.to(device)
 
diff --git a/utils.py b/utils.py
index b4293545ee08b7692043b2342dbd9684eaac345a..4361c3206fdf2ec52daffd7bb618782fbdc75933 100644
--- a/utils.py
+++ b/utils.py
@@ -59,3 +59,123 @@ def write_summary(logdir, model, optimizer, val_loss):
 
     summary_file.close()
 
+
+def create_submission(args, model, transform, device, rootDir, logdir):
+    cfg = yaml.load(config_file, Loader=yaml.FullLoader)
+    step_days = 10
+    batch_size = 1024
+    # We make chunks of num_days consecutive samples; As our dummy predictor
+    # is not using the temporal context, this is here arbitrarily chosen
+    # However, note that it must be a divisor of the total number of days
+    # in the 2017 year , either 1, 5, 73 or 365
+    num_days = cfg["Dataset"]["num_days"]
+    num_workers = 7
+
+    use_cuda = torch.cuda.is_available()
+    # Build the dataloaders
+    logging.info("Building the dataloader")
+
+    if args.PATHTOTESTSET != None:
+        test_loader = dataloader.get_test_dataloader(
+            args.PATHTOTESTSET,
+            num_days,
+            batch_size,
+            num_workers,
+            use_cuda,
+            overwrite_index=True,
+            transform=transform,
+            target_transform=None,
+        )
+    else :
+        test_loader = dataloader.get_test_dataloader(
+            dataloader._DEFAULT_TEST_FILEPATH,
+            num_days,
+            batch_size,
+            num_workers,
+            use_cuda,
+            overwrite_index=True,
+            transform=transform,
+            target_transform=None,
+        )
+    num_days_test = test_loader.dataset.ntimes
+
+    logging.info("= Filling in the submission file")
+    with open(logdir + "submission.csv", "w") as fh_submission:
+        fh_submission.write("Id,Predicted\n")
+        submission_offset = 0
+
+        # Iterate on the test dataloader
+        t_offset = 0
+        # Every minibatch will contain batch_size * num_days
+        # As we do not shuffle the data these correspond to consecutive
+        # days of the same location then followed by consecutive days of the
+        # next location and so on
+        chunk_size = batch_size * num_days
+        with torch.no_grad():
+            for X in tqdm.tqdm(test_loader):
+                X = X.to(device)
+                #############################################
+                # This is where you inject your knowledge
+                # About your model
+                # The rest of the code is generic as soon as you have a
+                # model working on time series
+                # X is (B, T, N)
+                # predictions are (B, T)
+                predictions = model(X)
+
+                #############################################
+
+                # we reshape it in (B * T)
+                # and keep only the time instants we need
+                predictions = predictions.view(-1)
+
+                # we need to slice the times by steps of days
+                # in chunks of num_test_days days (2017 had 365 days)
+                yearcut_indices = list(range(0, chunk_size + t_offset, num_days_test))
+                # The yearcut_indices are the indices in the linearized minibatch
+                # corresponding to the 01/01/2017 for some (latitude, longitude, depth)
+                # For these yearcut_indices, we can locate where to sample
+                # The vector of predictions
+                subdays_indices = [
+                    y + k
+                    for y in yearcut_indices
+                    for k in range(0, num_days_test, step_days)
+                ]
+                subdays_indices = list(map(lambda i: i - t_offset, subdays_indices))
+
+                # Remove the negative indices if any
+                # These negatives indices happen because of the offset
+                # These correspond to the locations of the 01/01/2017 in the previous
+                # minibatch
+                subdays_indices = [
+                    k
+                    for k in subdays_indices
+                    if 0 <= k < min(chunk_size, predictions.shape[0])
+                ]
+                t_offset = chunk_size - (yearcut_indices[-1] - t_offset)
+
+                predictions_list = predictions[subdays_indices].tolist()
+
+                # Check
+                # X = X.view(-1, 18)
+                # subX = X[yearcut_indices, :]
+                # # subX = X
+                # timestamps = subX[:, 3].tolist()
+                # print(
+                #     "\n".join(
+                #         [f"{datetime.datetime.fromtimestamp(x)}" for x in timestamps]
+                #     )
+                # )
+                # print("\n\n")
+                # sys.exit(-1)
+
+                # Dump the predictions to the submission file
+                submission_part = "\n".join(
+                    [
+                        f"{i+submission_offset},{pred}"
+                        for i, pred in enumerate(predictions_list)
+                    ]
+                )
+                fh_submission.write(submission_part + "\n")
+                submission_offset += len(predictions_list)
+    fh_submission.close()
\ No newline at end of file
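
The two snippets below are illustration only and are not part of the patch. The first is a minimal sketch of what the new dataloader.transform_remove_time() does to a raw attribute tensor and why main.py now selects input_size = 17 for it. It assumes 18 raw attributes per sample with the timestamp in column 3 (as implied by the slicing in dataloader.py); the local composite_transform helper is an assumption standing in for dataloader.composite_transform, with a left-to-right application order.

import torch

def transform_remove_time():
    # Mirrors the function added in dataloader.py: drop column 3 (the timestamp),
    # keep columns 0-2 (assumed to be the spatial coordinates) and columns 4-17.
    return lambda attributes: torch.cat((attributes[:, :3], attributes[:, 4:]), dim=1)

def composite_transform(*transforms):
    # Assumed semantics of dataloader.composite_transform: apply left to right.
    def apply(attributes):
        for t in transforms:
            attributes = t(attributes)
        return attributes
    return apply

x = torch.randn(5, 18)                     # 5 samples, 18 raw attributes
print(transform_remove_time()(x).shape)    # torch.Size([5, 17]) -> input_size = 17
# Stand-in for transform_normalize_with_train_statistics(MEAN, STD):
print(composite_transform(transform_remove_time(), lambda a: a * 2.0)(x).shape)  # still (5, 17)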
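The second snippet is a toy walk-through of the index bookkeeping in create_submission, using made-up numbers (a 6-day "year" per location, one prediction kept every 2 days, minibatches of 9 flattened days) instead of the real 365 / step_days=10 / batch_size * num_days values; the model and dataloader are left out so only the yearcut_indices / subdays_indices / t_offset arithmetic is exercised.

# Hypothetical toy values; the real code uses 365, 10 and batch_size * num_days.
num_days_test, step_days, chunk_size = 6, 2, 9

t_offset = 0
for batch in range(3):
    yearcut_indices = list(range(0, chunk_size + t_offset, num_days_test))
    subdays_indices = [y + k for y in yearcut_indices
                       for k in range(0, num_days_test, step_days)]
    subdays_indices = [i - t_offset for i in subdays_indices]
    # Negative indices were already handled in the previous minibatch;
    # indices >= chunk_size will be picked up by the next one.
    subdays_indices = [k for k in subdays_indices if 0 <= k < chunk_size]
    t_offset = chunk_size - (yearcut_indices[-1] - t_offset)
    print(batch, subdays_indices)
# 0 [0, 2, 4, 6, 8]   days 0,2,4 of location 0 and days 0,2 of location 1
# 1 [1, 3, 5, 7]      day 4 of location 1 and days 0,2,4 of location 2
# 2 [0, 2, 4, 6, 8]   days 0,2,4 of location 3 and days 0,2 of location 4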