Skip to content
Snippets Groups Projects
Commit 8157f9fb authored by Yandi's avatar Yandi
Browse files

[Second][FFN] Simple

parent c7e1483e
No related branches found
No related tags found
1 merge request!1Master into main
......@@ -9,6 +9,7 @@ import torch
import logging
import torch.optim
import torch.nn as nn
import create_submission
def optimizer(cfg, model):
result = {"Adam" : torch.optim.Adam(model.parameters())}
......@@ -36,7 +37,9 @@ if __name__ == "__main__":
use_cuda,
valid_ratio,
overwrite_index=True,
max_num_samples=max_num_samples
max_num_samples=max_num_samples,
train_transform=dataloader.transform_remove_space_time(),
valid_transform=dataloader.transform_remove_space_time()
)
if use_cuda :
......@@ -47,8 +50,13 @@ if __name__ == "__main__":
#model = model.build_model(cfg, 18)
model = nn.Sequential(
nn.Linear(18,1,False),
nn.ReLU()
nn.Linear(14,8,False),
nn.ReLU(),
nn.Linear(8, 8, True),
nn.ReLU(),
nn.Linear(8,35,True),
nn.ReLU(),
nn.Linear(35,1, True)
)
model = model.to(device)
......@@ -62,14 +70,16 @@ if __name__ == "__main__":
for t in range(cfg["Training"]["Epochs"]):
torch.autograd.set_detect_anomaly(True)
print("Epoch {}".format(t))
train(model, train_loader, f_loss, optimizer, device)
print(list(model.parameters())[0].grad)
#print(list(model.parameters())[0].grad)
val_loss = test.test(model, valid_loader, f_loss, device)
print(" Validation : Loss : {:.4f}".format(val_loss))
create_submission.create_submission(model)
"""
logdir = generate_unique_logpath(top_logdir, "linear")
print("Logging to {}".format(logdir))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment