Skip to content
Snippets Groups Projects
Commit 966a099e authored by Yandi's avatar Yandi
Browse files

RNN dropout 35 4

parent 6e8a88b4
No related branches found
No related tags found
1 merge request!1Master into main
......@@ -11,8 +11,8 @@ def makejob(commit_id, model, nruns, time_wall):
#SBATCH --nodes=1
#SBATCH --partition=gpu_prod_night
#SBATCH --time={time_wall}
#SBATCH --output=logslurms/slurm-%A_%a.out
#SBATCH --error=logslurms/slurm-%A_%a.err
#SBATCH --output=logslurms/slurm-{model}%A_%a.out
#SBATCH --error=logslurms/slurm-{model}%A_%a.err
#SBATCH --array=0-{nruns}
......@@ -66,7 +66,7 @@ parser = argparse.ArgumentParser()
parser.add_argument("--time_wall",
default="no_limit",
help="Time wall. Choose in [no-limit, hour, half, quarter]")
help="Time wall. Choose in [no_limit, hour, half, quarter]")
parser.add_argument("--model_name",
default ="Bi-LSTM",
......
File added
File added
File added
File added
......@@ -1333,3 +1333,30 @@ INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and
INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
INFO:root: - The train fold has 542473 samples
INFO:root: - The valid fold has 134687 samples
INFO:root:= Dataloaders for mean and standard deviation
INFO:root: - Dataset creation
INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 2222 time points
INFO:root:Generating the index
INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
INFO:root: - Loaded a dataset with 677160 samples
INFO:root: - Splitting the data in training and validation sets
INFO:root:Generating the subset files from 677160 samples
INFO:root: - Subset dataset
INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 2222 time points
INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
INFO:root: - The train fold has 541965 samples
INFO:root:= Dataloaders
INFO:root: - Dataset creation
INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 2222 time points
INFO:root:Generating the index
INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
INFO:root: - Loaded a dataset with 677160 samples
INFO:root: - Splitting the data in training and validation sets
INFO:root:Generating the subset files from 677160 samples
INFO:root: - Subset dataset
INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 2222 time points
INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
INFO:root:The loaded dataset contains 25 latitudes, 37 longitudes, 28 depths and 2222 time points
INFO:root:Loading the index from sub_2CMEMS-MEDSEA-2010-2016-training.nc.bin_index.idx
INFO:root: - The train fold has 541837 samples
INFO:root: - The valid fold has 135323 samples
......@@ -43,9 +43,15 @@ class RNN(nn.Module):
self.hidden_size = cfg["RNN"]["HiddenSize"]
self.num_layers = cfg["RNN"]["NumLayers"]
# RNN
self.rnn = nn.RNN(input_size, self.hidden_size, self.num_layers, batch_first=True, nonlinearity='relu')
self.rnn = nn.Sequential(
nn.RNN(input_size, self.hidden_size, self.num_layers, batch_first=True, nonlinearity='relu'),
nn.Dropout(p=0.2)
)
self.fc = nn.Linear(self.hidden_size, 1)
self.fc = nn.Sequential(
nn.Linear(self.hidden_size, 1),
nn.Dropout(p=0.2)
)
def forward(self, x):
use_cuda = torch.cuda.is_available()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment