Skip to content
Snippets Groups Projects
Commit c9b529d1 authored by Yandi's avatar Yandi
Browse files

cnn with dropout

parent 82be4e7a
No related branches found
No related tags found
1 merge request!1Master into main
......@@ -40,7 +40,7 @@ Training:
#Model selection
Model:
Name: BidirectionalLSTM
Name: CNN1D
#choose in {LinearRegression, BidirectionalLSTM, RNN}
#Model parameters selection
......@@ -64,6 +64,9 @@ RNN:
Dropout: 0.2
Initialization: None
CNN1D:
Initialization: None
#Name of directory containing logs
LogDir: ./logs/
......
This diff is collapsed.
......@@ -5,8 +5,6 @@ import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd import Variable
import torch.nn as nn
# Linear Regression (Feed forward Network)
class LinearRegression(nn.Module):
def __init__(self, cfg, input_size):
......@@ -115,6 +113,50 @@ class BidirectionalLSTM(nn.Module):
result = self.fc(out)
return result
# CNN
class CNN1D(torch.nn.Module):
def __init__(self, cfg, num_inputs):
super(CNN1D, self).__init__()
self.model = torch.nn.Sequential(
*conv_block(num_inputs, 5, 0.01),
*conv_block(32, 6, 0.01)
)
self.avg_pool = torch.nn.AdaptiveAvgPool1d(1)
self.fc = nn.Sequential(
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32,cfg["Dataset"]["num_days"])
)
def forward(self, x):
x = torch.transpose(x, 1, 2)
out = self.model(x)
out = self.avg_pool(out)
out = out.view([out.shape[0], -1])
out = self.fc(out)
out = out.view([out.shape[0], out.shape[1], 1])
return out
def conv_block(in_channels, power, dropout_p):
return [
torch.nn.Conv1d(in_channels, 2**power, 16),
torch.nn.BatchNorm1d(2**power),
torch.nn.ReLU(),
torch.nn.Dropout(p=dropout_p),
#torch.nn.Conv1d(2**power, 2**power, 8),
#torch.nn.BatchNorm1d(2**power),
#torch.nn.ReLU(),
#torch.nn.Dropout(p=dropout_p),
torch.nn.MaxPool1d(2, stride = 1)
]
# Initialization
def init_he(module):
if type(module)==nn.Linear:
......
......@@ -35,10 +35,10 @@ def test(model, loader, f_loss, device):
# Compute the forward pass, i.e. the scores for each input image
outputs = model(inputs)
print("Validation inputs :")
print(inputs)
print("Validation outputs :")
print(outputs)
#print("Validation inputs :")
#print(inputs)
#print("Validation outputs :")
#print(outputs)
# We accumulate the exact number of processed samples
N += inputs.shape[0]
......
......@@ -30,6 +30,8 @@ def train(args, model, loader, f_loss, optimizer, device, log_interval = 100):
inputs, targets = inputs.to(device), targets.to(device)
# Compute the forward pass through the network up to the loss
# target's shape is (B, Num_days, 1)
outputs = model(inputs)
loss = f_loss(outputs, targets)
......
No preview for this file type
No preview for this file type
run-20230204_010308-1aksu4p8/logs/debug-internal.log
\ No newline at end of file
run-20230204_013952-5w9xw0aw/logs/debug-internal.log
\ No newline at end of file
run-20230204_010308-1aksu4p8/logs/debug.log
\ No newline at end of file
run-20230204_013952-5w9xw0aw/logs/debug.log
\ No newline at end of file
run-20230204_010308-1aksu4p8
\ No newline at end of file
run-20230204_013952-5w9xw0aw
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment