diff --git a/main.py b/main.py index 30c623069f684fafa2b055df630ce41d9ec9239e..bfd0fe0c3315179ad6856fd9c0ac32ee6dba1b26 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,10 @@ import dataloader import model import test -import train +from train import train import yaml +import losses +import models if __name__ == "__main__": config_file = open("config.yml") @@ -27,11 +29,25 @@ if __name__ == "__main__": max_num_samples=max_num_samples, ) + if use_cuda : + device = torch.device('cuda') + else : + device = toch.device('cpu') + model = model.build_model(cfg, input_size) + f_loss = losses.RMSLE.RMSLE() + + optimizer = models.choose_optimizer.optimizer(cfg) + + train(model = model, loader = train_loader, f_loss = f_loss, optimizer = optimizer, device = device) + + + """ logdir = generate_unique_logpath(top_logdir, "linear") print("Logging to {}".format(logdir)) # -> Prints out Logging to ./logs/linear_1 if not os.path.exists(logdir): - os.mkdir(logdir) \ No newline at end of file + os.mkdir(logdir) + """ \ No newline at end of file