Commit 28e92ab9 authored by Yandi

adapted main to jobs and added no_wandb flags

parent 7238cdc6
1 merge request: !1 Master into main
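The commit message above summarizes the change: every wandb (Weights & Biases) call is now guarded by a --no_wandb flag so the script can run as an offline batch job, and the log directory becomes configurable through --rootDir. The guarding pattern reduces to the minimal sketch below, reusing the entity and project names that appear in the diff; val_loss is a placeholder value for illustration.

import argparse

import wandb

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no_wandb",
    action="store_true",
    help="Disable all wandb logging, e.g. when running batch jobs.",
)
args = parser.parse_args()

# Every wandb call is wrapped in the same guard, so a batch job
# never needs to reach the wandb service.
if not args.no_wandb:
    wandb.init(entity="wherephytoplankton", project="Kaggle phytoplancton")

val_loss = 0.0  # placeholder; the real script computes this once per epoch
if not args.no_wandb:
    wandb.log({"val_loss": val_loss})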
@@ -16,14 +16,33 @@ import logging
import torch.optim
import torch.nn as nn
import os
import argparse

def optimizer(cfg, network):
    # Map the optimizer name given in the config to a torch.optim instance
    result = {"Adam": torch.optim.Adam(network.parameters())}
    return result[cfg["Optimizer"]]

if __name__ == "__main__":
    # logging.basicConfig(filename='logs/main_unit_test.log', level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--no_wandb",
        action="store_true",
        help="If specified, no log will be sent to wandb. Especially useful when running batch jobs.",
    )
    parser.add_argument(
        "--rootDir",
        help="Directory in which the log files will be stored",
    )
    args = parser.parse_args()

    with open("config.yml") as config_file:
        cfg = yaml.safe_load(config_file)

    # Fall back to the LogDir entry of the config when --rootDir is not given
    rootDir = cfg["LogDir"] if args.rootDir is None else args.rootDir
    logging.basicConfig(filename=rootDir + 'main_unit_test.log', level=logging.INFO)
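The optimizer helper defined just after the imports above instantiates every optimizer in its dictionary before selecting one. A lazier variant, sketched here under the assumption that only the optimizer's name is read from the config, maps names to constructors instead; build_optimizer is a hypothetical name, not part of the repository.

import torch.optim


def build_optimizer(cfg, network):
    # Hypothetical variant: store constructors rather than instances, so only
    # the optimizer named in the config is ever built.
    constructors = {
        "Adam": torch.optim.Adam,
        "SGD": lambda params: torch.optim.SGD(params, lr=0.01),
    }
    return constructors[cfg["Optimizer"]](network.parameters())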
@@ -39,12 +58,12 @@ if __name__ == "__main__":
    log_freq = int(cfg["Wandb"]["log_freq"])
    log_interval = int(cfg["Wandb"]["log_interval"])
    if not args.no_wandb:
        wandb.init(entity="wherephytoplankton", project="Kaggle phytoplancton", config={"batch_size": batch_size, "epochs": epochs})
    # Re-compute the statistics or use the stored ones
    approx_stats = cfg["ApproximativeStats"]
    print(approx_stats)

    if approx_stats:
        MEAN = eval(cfg["ApproximativeMean"])
@@ -73,8 +92,6 @@ if __name__ == "__main__":
        valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
        # train_transform = dataloader.transform_remove_space_time(),
        # valid_transform = dataloader.transform_remove_space_time()
        train_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)),
        valid_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX))
    )
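The train and validation transforms above are built by composing two dataloader helpers whose implementation is not part of this diff; functionally they presumably behave like the sketch below (the names are reused from the diff, the bodies are assumptions).

def composite_transform(first, second):
    # Apply 'first' to a sample, then 'second' to the result.
    def apply(x):
        return second(first(x))
    return apply


def transform_min_max_scaling(minimum, maximum):
    # Rescale features into [0, 1] using precomputed dataset statistics.
    def apply(x):
        return (x - minimum) / (maximum - minimum)
    return apply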
@@ -94,33 +111,24 @@ if __name__ == "__main__":
    # note: this rebinds the name 'optimizer' from the helper function to the instance
    optimizer = optimizer(cfg, network)

    # logdir, raw_run_name = utils.create_unique_logpath(cfg["LogDir"], cfg["Model"]["Name"])
    logdir, raw_run_name = utils.create_unique_logpath(rootDir, cfg["Model"]["Name"])
    network_checkpoint = model.ModelCheckpoint(logdir + "/best_model.pt", network)

    if not args.no_wandb:
        wandb.run.name = raw_run_name
        wandb.watch(network, log_freq=log_freq)
    for t in range(cfg["Training"]["Epochs"]):
        # Anomaly detection gives clearer errors for NaN/inf gradients,
        # at the cost of slower training.
        torch.autograd.set_detect_anomaly(True)
        print("Epoch {}".format(t))

        # train(network, train_loader, f_loss, optimizer, device, log_interval)
        train(args, network, train_loader, f_loss, optimizer, device, log_interval)
        # print(list(network.parameters())[0].grad)

        val_loss = test.test(network, valid_loader, f_loss, device)
        network_checkpoint.update(val_loss)
        print(" Validation : Loss : {:.4f}".format(val_loss))

        if not args.no_wandb:
            wandb.log({"val_loss": val_loss})
    create_submission.create_submission(network, dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)), device)

    """
    logdir = generate_unique_logpath(top_logdir, "linear")
    print("Logging to {}".format(logdir))
    # -> Prints out Logging to ./logs/linear_1
    if not os.path.exists(logdir):
        os.mkdir(logdir)
    """
\ No newline at end of file
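Assuming the file shown here is the training entry point (for example main.py), a batch job would invoke it with both new options; since rootDir is concatenated directly with the log file name, the trailing slash matters:

python main.py --no_wandb --rootDir ./logs/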