diff --git a/main.py b/main.py
index 5ff24ec5e944e87d2d88c6a0a37e83d671f04297..abb6709a9c95724f1a0943ce256a1408eb16e0ba 100644
--- a/main.py
+++ b/main.py
@@ -16,14 +16,33 @@ import logging
 import torch.optim
 import torch.nn as nn
 import os
+import argparse
 
 def optimizer(cfg, network):
     result = {"Adam" : torch.optim.Adam(network.parameters())}
     return result[cfg["Optimizer"]]
 
 if __name__ == "__main__":
-    logging.basicConfig(filename='logs/main_unit_test.log', level=logging.INFO)
+    parser = argparse.ArgumentParser()
 
+    parser.add_argument(
+    "--no_wandb",
+    action="store_true",
+    help="If specified, no log will be sent to wandb. Especially useful when running batch jobs.",
+    )
+
+    parser.add_argument(
+    "--rootDir",
+    help="Directory in which the log files will be stored"
+    )
+
+    args = parser.parse_args()
+
+
+    rootDir = cfg["LogDir"] if eval(args.rootDir) != None else args.rootDir
+
+    logging.basicConfig(filename= rootDir + 'main_unit_test.log', level=logging.INFO)
+    
     config_file = open("config.yml")
     cfg = yaml.load(config_file)
 
@@ -39,12 +58,12 @@ if __name__ == "__main__":
     log_freq        = int(cfg["Wandb"]["log_freq"])
     log_interval    = int(cfg["Wandb"]["log_interval"])
 
+    if not args.no_wandb:
+        wandb.init(entity = "wherephytoplankton", project = "Kaggle phytoplancton", config = {"batch_size": batch_size, "epochs": epochs})
 
-    wandb.init(entity = "wherephytoplankton", project = "Kaggle phytoplancton", config = {"batch_size": batch_size, "epochs": epochs})
-
+    # Re-compute the statistics or use the stored ones
     approx_stats = cfg["ApproximativeStats"]
 
-    print(approx_stats)
 
     if approx_stats:
         MEAN    = eval(cfg["ApproximativeMean"])
@@ -73,8 +92,6 @@ if __name__ == "__main__":
         valid_ratio,
         overwrite_index =  True,
         max_num_samples=max_num_samples,
-        #train_transform = dataloader.transform_remove_space_time(),
-        #valid_transform = dataloader.transform_remove_space_time()
         train_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)),
         valid_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX))
     )
@@ -94,33 +111,24 @@ if __name__ == "__main__":
 
     optimizer = optimizer(cfg, network)
 
-    logdir, raw_run_name = utils.create_unique_logpath(cfg["LogDir"], cfg["Model"]["Name"])
-    wandb.run.name = raw_run_name
+    logdir, raw_run_name = utils.create_unique_logpath(rootDir, cfg["Model"]["Name"])
     network_checkpoint = model.ModelCheckpoint(logdir + "/best_model.pt", network)
 
-
-    wandb.watch(network, log_freq = log_freq)
+    if not args.no_wandb:
+        wandb.run.name = raw_run_name
+        wandb.watch(network, log_freq = log_freq)
 
     for t in range(cfg["Training"]["Epochs"]):
-        torch.autograd.set_detect_anomaly(True)
         print("Epoch {}".format(t))
-        train(network, train_loader, f_loss, optimizer, device, log_interval)
-
+        train(args, network, train_loader, f_loss, optimizer, device, log_interval)
 
-        #print(list(network.parameters())[0].grad)
         val_loss = test.test(network, valid_loader, f_loss, device)
 
         network_checkpoint.update(val_loss)
 
         print(" Validation : Loss : {:.4f}".format(val_loss))
-        wandb.log({"val_loss": val_loss})
+        if not args.no_wandb:
+            wandb.log({"val_loss": val_loss})
 
 
     create_submission.create_submission(network, dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)), device)
-    """
-    logdir = generate_unique_logpath(top_logdir, "linear")
-    print("Logging to {}".format(logdir))
-    # -> Prints out     Logging to   ./logs/linear_1
-    if not os.path.exists(logdir):
-        os.mkdir(logdir)
-    """
\ No newline at end of file