Skip to content
Snippets Groups Projects
main.py 1.05 KiB
Newer Older
import os

import torch
import yaml

import dataloader
import model
import test
import train

def load_config(path="config.yml"):
    """Parse the YAML configuration file at *path* and return its contents.

    ``yaml.safe_load`` is used instead of a bare ``yaml.load(f)``: PyYAML 6
    rejects ``load`` without an explicit ``Loader``, and ``safe_load`` only
    builds plain Python objects (dicts, lists, scalars) from the document.
    """
    # ``with`` guarantees the handle is closed (the original leaked it).
    with open(path, encoding="utf-8") as config_file:
        return yaml.safe_load(config_file)


def generate_unique_logpath(logdir, raw_run_name):
    """Return the first unused path ``<logdir>/<raw_run_name>_<i>``.

    Candidates are tried for ``i = 0, 1, 2, ...``; the first one that does not
    exist on disk (as a directory *or* a file) is returned. Nothing is created.

    Args:
        logdir: Parent directory of the run directories; it need not exist.
        raw_run_name: Prefix of the run directory name, e.g. ``"linear"``.

    Returns:
        The path string of the first free run directory.
    """
    index = 0
    while True:
        log_path = os.path.join(logdir, f"{raw_run_name}_{index}")
        # ``exists`` rather than ``isdir``: a stray file with that name would
        # otherwise be returned and make the later directory creation fail.
        if not os.path.exists(log_path):
            return log_path
        index += 1


def main(config_path="config.yml", top_logdir="./logs"):
    """Load the config, build the data loaders and the model, create a log dir.

    Args:
        config_path: YAML file whose ``"Dataset"`` section holds the keys
            ``_DEFAULT_TRAIN_FILEPATH``, ``num_days``, ``batch_size``,
            ``num_workers``, ``valid_ratio`` and ``max_num_samples``.
        top_logdir: Directory under which a fresh ``linear_<i>`` run
            directory is created; it is created too if missing.
    """
    cfg = load_config(config_path)
    dataset_cfg = cfg["Dataset"]

    use_cuda = torch.cuda.is_available()

    train_loader, valid_loader = dataloader.get_dataloaders(
        dataset_cfg["_DEFAULT_TRAIN_FILEPATH"],
        dataset_cfg["num_days"],
        dataset_cfg["batch_size"],
        dataset_cfg["num_workers"],
        use_cuda,
        dataset_cfg["valid_ratio"],
        overwrite_index=True,
        max_num_samples=dataset_cfg["max_num_samples"],
    )

    # NOTE(review): ``input_size`` was never defined in the original script
    # (NameError). It is inferred here from one training batch, ASSUMING the
    # loader yields ``(inputs, targets)`` pairs and that ``build_model``
    # expects the per-sample input shape -- confirm against model.build_model.
    sample_inputs, _ = next(iter(train_loader))
    input_size = tuple(sample_inputs.shape[1:])

    # Named ``network`` so it does not rebind (shadow) the ``model`` module.
    network = model.build_model(cfg, input_size)

    logdir = generate_unique_logpath(top_logdir, "linear")
    print(f"Logging to {logdir}")
    # e.g. "Logging to ./logs/linear_0", then "./logs/linear_1" on the next run.
    # ``makedirs`` also creates ``top_logdir`` when it is missing, where the
    # original ``os.mkdir`` would raise FileNotFoundError.
    os.makedirs(logdir, exist_ok=True)


if __name__ == "__main__":
    main()