From 84499acdea9f8a070b5bb2d1a399ead71fca07af Mon Sep 17 00:00:00 2001
From: Yandi <yandirzm@gmail.com>
Date: Sat, 21 Jan 2023 20:31:50 +0100
Subject: [PATCH] [Dev] Getting used to the dataset architecture

---
 dataloader.py | 56 ++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 49 insertions(+), 7 deletions(-)

diff --git a/dataloader.py b/dataloader.py
index 355a91d..d45ab4d 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -186,7 +186,7 @@ if __name__ == "__main__":
     logging.info("====> Test dataloader")
     use_cuda = torch.cuda.is_available()
     trainpath = _DEFAULT_TRAIN_FILEPATH
-    num_days = 1  # Test with sequence of 1 day
+    num_days = 35  # Test with sequence of 1 day
     batch_size = 128
     num_workers = 7
     valid_ratio = 0.2
@@ -194,16 +194,58 @@ if __name__ == "__main__":
     max_num_samples = None
 
     train_loader, valid_loader = get_dataloaders(
-        trainpath,
-        num_days,
-        batch_size,
-        num_workers,
-        use_cuda,
-        valid_ratio,
+        filepath = trainpath,
+        num_days = num_days,
+        batch_size = batch_size,
+        num_workers = num_workers,
+        pin_memory = False,
+        valid_ratio = valid_ratio,
         overwrite_index=True,
         max_num_samples=max_num_samples,
     )
 
     it = iter(train_loader)
     X, Y = next(it)
+
+    def check_min(tensor, index):
+        """
+        For a tensor of shape (B, T, N) return the min value of N[index]
+        """
+        mini = 1e99
+        for batch in tensor:
+            for value_time in batch:
+                value = value_time[index].item()
+                if value < mini :
+                    mini = value
+        print(f"Min value for index {index} is ")
+        return mini
+
+    def check_max(tensor, index):
+        """
+        For a tensor of shape (B, T, N) return the max value of N[index]
+        """
+        maxi = -1e99
+        for batch in tensor:
+            for value_time in batch:
+                value = value_time[index].item()
+                if value > maxi :
+                    maxi = value
+        print(f"Max value for index {index} is ")
+        return maxi
+
+    def check_info(tensor, index):
+        print("="*30)
+        print(check_min(tensor, index))
+        print(check_max(tensor, index))
+        print("="*30)
+
+    check_info(X,0) #latitude
+    check_info(X,1) #longitude
+    check_info(X,2) #depth
+    check_info(X,3) #time
+
+    """
+    Size of X = (Batchsize, Number of days, Features (18) (lat, long, depth, time, features))
+    Size of Y = (Batchsize, Number of days)
+    """
     logging.info(f"Got a minibatch of size {X.shape} -> {Y.shape}")
\ No newline at end of file
-- 
GitLab