From 43caf162d0e56244f5f5f2647c770ee86d73b659 Mon Sep 17 00:00:00 2001
From: Yandi <yandirzm@gmail.com>
Date: Fri, 3 Feb 2023 16:50:59 +0100
Subject: [PATCH] Parametrize dropout in RNN model

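Read the dropout probability of the fully connected head from
config.yml (key RNN.Dropout) instead of hard-coding p=0.2 in model.py,
so it can be tuned per experiment without touching the code. The new
key defaults to the previous value, 0.2.

For reference, a minimal sketch of how the key would reach nn.Dropout,
assuming the config is loaded with PyYAML; the repo's actual loading
code may differ:

    import yaml
    import torch.nn as nn

    # Hypothetical loading step; mirrors how cfg is indexed in model.py.
    with open("config.yml") as f:
        cfg = yaml.safe_load(f)

    # Every dropout layer in the FFN head is now built from this value
    # instead of the hard-coded p=0.2.
    drop = nn.Dropout(p=cfg["RNN"]["Dropout"])
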
---
 config.yml | 1 +
 model.py   | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/config.yml b/config.yml
index 9f40301..cbca0cb 100644
--- a/config.yml
+++ b/config.yml
@@ -58,6 +58,7 @@ RNN:
   HiddenSize: 35
   NumLayers: 4
   NumFFN: 15
+  Dropout: 0.2
   Initialization: None
 
 #Name of directory containing logs
diff --git a/model.py b/model.py
index f3860b6..4a3e033 100644
--- a/model.py
+++ b/model.py
@@ -43,6 +43,7 @@ class RNN(nn.Module):
         self.hidden_size = cfg["RNN"]["HiddenSize"]
         self.num_layers = cfg["RNN"]["NumLayers"]
         self.num_ffn = cfg["RNN"]["NumFFN"]
+        self.dropout = cfg["RNN"]["Dropout"]
         # RNN
         self.rnn = nn.Sequential(
             nn.RNN(input_size, self.hidden_size, self.num_layers, batch_first=True, nonlinearity='relu'),
@@ -62,7 +63,7 @@ class RNN(nn.Module):
             )
             self.fc.add_module(
                 f"dropout_{layer}", 
-                nn.Dropout(p=0.2)
+                nn.Dropout(p=self.dropout)
             )
         
         self.fc.add_module(
-- 
GitLab