diff --git a/config.yml b/config.yml index 9f40301bdc55237a56ca89640b59f58f4973151e..cbca0cb83d1a6d7b341e357040cc9a3b9c33c309 100644 --- a/config.yml +++ b/config.yml @@ -58,6 +58,7 @@ RNN: HiddenSize: 35 NumLayers: 4 NumFFN: 15 + Dropout: 0.2 Initialization: None #Name of directory containing logs diff --git a/model.py b/model.py index f3860b6515fffe682592cdb8db547d06fa89fe4c..4a3e03387a6918c12f87f8cdb0438bbb9497a26d 100644 --- a/model.py +++ b/model.py @@ -43,6 +43,7 @@ class RNN(nn.Module): self.hidden_size = cfg["RNN"]["HiddenSize"] self.num_layers = cfg["RNN"]["NumLayers"] self.num_ffn = cfg["RNN"]["NumFFN"] + self.dropout = cfg["RNN"]["Dropout"] # RNN self.rnn = nn.Sequential( nn.RNN(input_size, self.hidden_size, self.num_layers, batch_first=True, nonlinearity='relu'), @@ -62,7 +63,7 @@ class RNN(nn.Module): ) self.fc.add_module( f"dropout_{layer}", - nn.Dropout(p=0.2) + nn.Dropout(p=self.dropout) ) self.fc.add_module(