# NOTE(review): "Newer" / "Older" below look like leftover merge/diff labels;
# commented out so the file is importable — confirm against version control.
# Newer
# Older
# Standard library
import argparse
import logging

#External imports
import torch
import yaml
# NOTE(review): `wandb`, `train` and `f_loss` are also used below but never
# imported in this file — confirm where they are meant to come from.

# Local project modules
import dataloader
import model
import test
import create_submission
import utils
def optimizer(cfg, network):
    """Build the optimizer selected by the configuration.

    Args:
        cfg: parsed configuration mapping. The optimizer name is read from
            cfg["Optimizer"] when present, defaulting to "Adam".
            # NOTE(review): exact key name assumed — confirm against config.yml.
        network: the torch.nn.Module whose parameters will be optimized.

    Returns:
        A torch.optim.Optimizer instance bound to ``network``'s parameters.

    Raises:
        KeyError: if the configured optimizer name is not supported.
    """
    # Dispatch table of supported optimizers; only Adam for now.
    result = {"Adam": torch.optim.Adam(network.parameters())}
    # BUG FIX: the original built the dict but fell off the end, returning None.
    return result[cfg.get("Optimizer", "Adam")]
# --- Command-line interface ------------------------------------------------
# BUG FIX: `parser` was used without ever being created (NameError).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--no_wandb",
    action="store_true",
    help="If specified, no log will be sent to wandb. Especially useful when running batch jobs.",
)
parser.add_argument(
    "--rootDir",
    help="Directory in which the log files will be stored"
)
args = parser.parse_args()

# --- Configuration -----------------------------------------------------------
# Context manager guarantees the handle is closed; safe_load because
# yaml.load() without an explicit Loader is deprecated and unsafe.
with open("config.yml") as config_file:
    cfg = yaml.safe_load(config_file)

# CLI flag overrides the configured log directory.
rootDir = args.rootDir if args.rootDir is not None else cfg["LogDir"]
# NOTE(review): plain concatenation assumes rootDir ends with a path
# separator — confirm, or switch to os.path.join.
logging.basicConfig(filename=rootDir + 'main_unit_test.log', level=logging.INFO)

use_cuda = torch.cuda.is_available()

# Dataset / training / logging hyper-parameters from the config file.
trainpath = cfg["Dataset"]["_DEFAULT_TRAIN_FILEPATH"]
num_days = int(cfg["Dataset"]["num_days"])
batch_size = int(cfg["Dataset"]["batch_size"])
num_workers = int(cfg["Dataset"]["num_workers"])
valid_ratio = float(cfg["Dataset"]["valid_ratio"])
# SECURITY NOTE: eval() executes arbitrary code from the config file — only
# acceptable for trusted, local configs. Prefer ast.literal_eval.
max_num_samples = eval(cfg["Dataset"]["max_num_samples"])
epochs = int(cfg["Training"]["Epochs"])
log_freq = int(cfg["Wandb"]["log_freq"])
log_interval = int(cfg["Wandb"]["log_interval"])
# Initialize experiment tracking unless explicitly disabled on the CLI.
if not args.no_wandb:
    wandb.init(entity="wherephytoplankton", project="Kaggle phytoplancton", config={"batch_size": batch_size, "epochs": epochs})

# Re-compute the statistics or use the stored ones
approx_stats = cfg["ApproximativeStats"]
if approx_stats:
    # SECURITY NOTE: eval() on config strings — trusted config files only.
    MEAN = eval(cfg["ApproximativeMean"])
    STD = eval(cfg["ApproximativeSTD"])
    MAX = eval(cfg["ApproximativeMaxi"])
    MIN = eval(cfg["ApproximativeMini"])
else:
    # BUG FIX: this expensive recomputation ran unconditionally in the
    # original, clobbering the approximate stats read just above.
    MEAN, STD, MAX, MIN = dataloader.get_stats_train_dataset(
        trainpath,
        num_days,
        batch_size,
        num_workers,
        use_cuda,
        valid_ratio,
        overwrite_index=True,
        max_num_samples=max_num_samples,
        train_transform=None,
        valid_transform=None,
    )
# Build the train/validation loaders. Both pipelines drop the space-time
# channels, then min-max scale with the dataset statistics.
# BUG FIX: the original call was missing its closing parenthesis (SyntaxError).
train_loader, valid_loader = dataloader.get_dataloaders(
    trainpath,
    num_days,
    batch_size,
    num_workers,
    use_cuda,
    valid_ratio,
    train_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)),
    valid_transform=dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)),
)
# Select the computation device.
if use_cuda:
    device = torch.device('cuda')
else:
    # BUG FIX: "toch.device" typo crashed the CPU path with a NameError.
    device = torch.device('cpu')

# Build and initialize the network (14 input features).
network = model.build_model(cfg, 14)
model.initialize_model(cfg, network)

# Unique per-run log directory and best-model checkpointing on val loss.
logdir, raw_run_name = utils.create_unique_logpath(rootDir, cfg["Model"]["Name"])
network_checkpoint = model.ModelCheckpoint(logdir + "/best_model.pt", network)

if not args.no_wandb:
    wandb.run.name = raw_run_name
    wandb.watch(network, log_freq=log_freq)

# --- Training loop -----------------------------------------------------------
for t in range(epochs):  # `epochs` already parsed from cfg["Training"]["Epochs"]
    print("Epoch {}".format(t))
    # NOTE(review): `train` and `f_loss` are never imported or defined in this
    # file, and `optimizer` passed here is the factory function itself —
    # confirm the missing import and that train() calls optimizer(cfg, net).
    train(args, network, train_loader, f_loss, optimizer, device, log_interval)
    val_loss = test.test(network, valid_loader, f_loss, device)
    network_checkpoint.update(val_loss)
    if not args.no_wandb:
        wandb.log({"val_loss": val_loss})

# Final submission uses the same normalization pipeline as training.
create_submission.create_submission(network, dataloader.composite_transform(dataloader.transform_remove_space_time(), dataloader.transform_min_max_scaling(MIN, MAX)), device, rootDir)