Skip to content
Snippets Groups Projects
Commit 9e3191da authored by Yandi's avatar Yandi
Browse files

deeper cnn

parent 9b430d78
No related branches found
No related tags found
1 merge request!1Master into main
No preview for this file type
...@@ -120,14 +120,14 @@ class CNN1D(torch.nn.Module): ...@@ -120,14 +120,14 @@ class CNN1D(torch.nn.Module):
def __init__(self, cfg, num_inputs): def __init__(self, cfg, num_inputs):
super(CNN1D, self).__init__() super(CNN1D, self).__init__()
self.model = torch.nn.Sequential( self.model = torch.nn.Sequential(
*conv_block(num_inputs, 5, 0.01), *conv_block(num_inputs, 6, 0.01),
*conv_block(32, 6, 0.1), *conv_block(64, 7, 0.01),
*conv_block(64,7,0.01) *conv_block(128,8,0.01)
) )
self.avg_pool = torch.nn.AdaptiveAvgPool1d(1) self.avg_pool = torch.nn.AdaptiveAvgPool1d(1)
self.fc = nn.Sequential( self.fc = nn.Sequential(
nn.Linear(128, 32), nn.Linear(256, 32),
nn.ReLU(), nn.ReLU(),
nn.Linear(32,cfg["Dataset"]["num_days"]) nn.Linear(32,cfg["Dataset"]["num_days"])
) )
...@@ -148,12 +148,12 @@ class CNN1D(torch.nn.Module): ...@@ -148,12 +148,12 @@ class CNN1D(torch.nn.Module):
def conv_block(in_channels, power, dropout_p): def conv_block(in_channels, power, dropout_p):
return [ return [
torch.nn.Conv1d(in_channels, 2**power, 16), torch.nn.Conv1d(in_channels, 2**power, 16),
torch.nn.BatchNorm1d(2**power),
torch.nn.ReLU(),
torch.nn.Dropout(p=dropout_p),
#torch.nn.Conv1d(2**power, 2**power, 8),
#torch.nn.BatchNorm1d(2**power), #torch.nn.BatchNorm1d(2**power),
#torch.nn.ReLU(), torch.nn.LeakyReLU(),
torch.nn.Dropout(p=dropout_p),
torch.nn.Conv1d(2**power, 2**power, 8),
torch.nn.BatchNorm1d(2**power),
torch.nn.LeakyReLU(),
#torch.nn.Dropout(p=dropout_p), #torch.nn.Dropout(p=dropout_p),
torch.nn.MaxPool1d(2, stride = 1) torch.nn.MaxPool1d(2, stride = 1)
] ]
......
import torch.optim import torch.optim
def optimizer(cfg, model): def optimizer(cfg, model):
result = {"Adam" : torch.optim.Adam(model.parameters(), lr = 1e-2)} result = {"Adam" : torch.optim.Adam(model.parameters(), lr = 1e-3)}
return result[cfg["Optimizer"]] return result[cfg["Optimizer"]]
No preview for this file type
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment