From a2359870697ae0438568e6202a7607dc8c021fef Mon Sep 17 00:00:00 2001 From: Yandi <yandirzm@gmail.com> Date: Sat, 21 Jan 2023 20:52:33 +0100 Subject: [PATCH] [Run] Storing before testing main --- main.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 30c6230..bfd0fe0 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,10 @@ import dataloader import model import test -import train +from train import train import yaml +import losses +import models if __name__ == "__main__": config_file = open("config.yml") @@ -27,11 +29,25 @@ if __name__ == "__main__": max_num_samples=max_num_samples, ) + if use_cuda : + device = torch.device('cuda') + else : + device = toch.device('cpu') + model = model.build_model(cfg, input_size) + f_loss = losses.RMSLE.RMSLE() + + optimizer = models.choose_optimizer.optimizer(cfg) + + train(model = model, loader = train_loader, f_loss = f_loss, optimizer = optimizer, device = device) + + + """ logdir = generate_unique_logpath(top_logdir, "linear") print("Logging to {}".format(logdir)) # -> Prints out Logging to ./logs/linear_1 if not os.path.exists(logdir): - os.mkdir(logdir) \ No newline at end of file + os.mkdir(logdir) + """ \ No newline at end of file -- GitLab