Skip to content
Snippets Groups Projects
Commit 43caf162 authored by Yandi's avatar Yandi
Browse files

parametrized dropout in RNN model

parent 070d97c9
Branches
No related tags found
1 merge request!1Master into main
...@@ -58,6 +58,7 @@ RNN: ...@@ -58,6 +58,7 @@ RNN:
HiddenSize: 35 HiddenSize: 35
NumLayers: 4 NumLayers: 4
NumFFN: 15 NumFFN: 15
Dropout: 0.2
Initialization: None Initialization: None
#Name of directory containing logs #Name of directory containing logs
......
...@@ -43,6 +43,7 @@ class RNN(nn.Module): ...@@ -43,6 +43,7 @@ class RNN(nn.Module):
self.hidden_size = cfg["RNN"]["HiddenSize"] self.hidden_size = cfg["RNN"]["HiddenSize"]
self.num_layers = cfg["RNN"]["NumLayers"] self.num_layers = cfg["RNN"]["NumLayers"]
self.num_ffn = cfg["RNN"]["NumFFN"] self.num_ffn = cfg["RNN"]["NumFFN"]
self.dropout = cfg["RNN"]["Dropout"]
# RNN # RNN
self.rnn = nn.Sequential( self.rnn = nn.Sequential(
nn.RNN(input_size, self.hidden_size, self.num_layers, batch_first=True, nonlinearity='relu'), nn.RNN(input_size, self.hidden_size, self.num_layers, batch_first=True, nonlinearity='relu'),
...@@ -62,7 +63,7 @@ class RNN(nn.Module): ...@@ -62,7 +63,7 @@ class RNN(nn.Module):
) )
self.fc.add_module( self.fc.add_module(
f"dropout_{layer}", f"dropout_{layer}",
nn.Dropout(p=0.2) nn.Dropout(p=self.dropout)
) )
self.fc.add_module( self.fc.add_module(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment