Commit c3dbc013 authored by Simon

update to 2020-2021 version

parent 80c2e7e2
......@@ -382,7 +382,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.8.3"
}
},
"nbformat": 4,
......
%% Cell type:markdown id: tags:
# Preliminaries
%% Cell type:markdown id: tags:
## Environment setup
### Mount Google Drive
%% Cell type:code id: tags:
``` python
from google.colab import drive
drive.mount('/content/drive')
```
%% Cell type:markdown id: tags:
### Get some utilities
%% Cell type:code id: tags:
``` python
import os
if not os.path.exists('mel_features.py'):
    !wget https://gitlab-research.centralesupelec.fr/sleglaive/embedded-ust-students/raw/master/mel_features.py
if not os.path.exists('utils.py'):
    !wget https://gitlab-research.centralesupelec.fr/sleglaive/embedded-ust-students/raw/master/utils.py
if not os.path.exists('vggish_params.py'):
    !wget https://gitlab-research.centralesupelec.fr/sleglaive/embedded-ust-students/raw/master/vggish_params.py
```
%% Cell type:markdown id: tags:
### Define important paths
%% Cell type:code id: tags:
``` python
ust_data_dir = './drive/My Drive/data/ust-data'
dataset_dir = os.path.join(ust_data_dir, 'sonyc-ust')
annotation_file = os.path.join(dataset_dir, 'annotations.csv')
taxonomy_file = os.path.join(dataset_dir, 'dcase-ust-taxonomy.yaml')
log_mel_spec_dir = os.path.join(ust_data_dir, 'log-mel-spectrograms')
output_training_dir = os.path.join(ust_data_dir, 'output_training')
output_prediction_dir = os.path.join(ust_data_dir, 'output_prediction')
```
%% Cell type:markdown id: tags:
### Install missing packages
%% Cell type:code id: tags:
``` python
!pip install oyaml
```
%% Cell type:markdown id: tags:
## Exploring the dataset
%% Cell type:markdown id: tags:
We will use [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) to manipulate the dataset.
%% Cell type:code id: tags:
``` python
import pandas as pd
import oyaml as yaml
from utils import get_file_targets, get_subset_split
import numpy as np
```
%% Cell type:code id: tags:
``` python
# Create a Pandas DataFrame from the annotation CSV file
annotation_data = pd.read_csv(annotation_file).sort_values('audio_filename')
# You can view the top rows of the frame with
annotation_data.head()
```
%% Cell type:code id: tags:
``` python
# List of all audio files
file_list = annotation_data['audio_filename'].unique().tolist()
```
%% Cell type:code id: tags:
``` python
# Load taxonomy
with open(taxonomy_file, 'r') as f:
    taxonomy = yaml.load(f, Loader=yaml.Loader)
# get list of labels from taxonomy
labels = ["_".join([str(k), v]) for k,v in taxonomy['coarse'].items()]
# number of classes
n_classes = len(labels)
print(labels)
```
%% Cell type:code id: tags:
``` python
# get list of one-hot encoded labels for all audio files
target_list = get_file_targets(annotation_data, labels)
# get list of indices for the training, validation and test subsets
train_file_idxs, val_file_idxs, test_file_idxs = get_subset_split(annotation_data)
```
%% Cell type:markdown id: tags:
For each split (training, validation, test) and each label, we compute the proportion of files that contain this label.
%% Cell type:code id: tags:
``` python
train_proportions = np.sum(target_list[train_file_idxs, :],
                           axis=0) / len(train_file_idxs)
val_proportions = np.sum(target_list[val_file_idxs, :],
                         axis=0) / len(val_file_idxs)
test_proportions = np.sum(target_list[test_file_idxs, :],
                          axis=0) / len(test_file_idxs)
print('Distribution of classes in the training set:')
for idx, label in enumerate(labels):
    print(label + ': {:.2%}'.format(train_proportions[idx]))
print('\n')
print('Distribution of classes in the validation set:')
for idx, label in enumerate(labels):
    print(label + ': {:.2%}'.format(val_proportions[idx]))
print('\n')
print('Distribution of classes in the test set:')
for idx, label in enumerate(labels):
    print(label + ': {:.2%}'.format(test_proportions[idx]))
```
%% Cell type:markdown id: tags:
---
### Question
What conclusions can we draw from the distribution of classes in the training set?
---
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
## Audio basics
We will use the following libraries for loading, writing, and playing audio signals:
1. [Librosa](https://librosa.github.io/librosa/index.html) is a Python package for music and audio processing.
2. [PySoundFile](https://pysoundfile.readthedocs.io/en/0.8.1) is an audio library based on libsndfile, CFFI and NumPy.
3. [IPython.display.Audio](https://ipython.org/ipython-doc/stable/api/generated/IPython.display.html#IPython.display.Audio) lets you play audio directly in notebooks.
%% Cell type:markdown id: tags:
### Reading audio
Use [`librosa.load`](https://librosa.github.io/librosa/generated/librosa.core.load.html#librosa.core.load) to load an audio file into an audio array. It returns both the audio array and the sample rate:
%% Cell type:code id: tags:
``` python
import librosa
# get a file in the training set
training_file_list = [file_list[ind] for ind in train_file_idxs]
audio_file = os.path.join(dataset_dir, 'audio-dev/train',
training_file_list[10])
x, sr = librosa.load(audio_file, mono=True, sr=None)
```
%% Cell type:markdown id: tags:
Display the length of the audio array and sample rate:
%% Cell type:code id: tags:
``` python
print(x.shape)
print(sr)
```
%% Cell type:code id: tags:
``` python
import resampy
import vggish_params
old_sr = sr
sr = vggish_params.SAMPLE_RATE
x = resampy.resample(x, old_sr, sr)
```
%% Cell type:markdown id: tags:
### Visualizing Audio
%% Cell type:markdown id: tags:
In order to display plots inside the Jupyter notebook, run the following command:
%% Cell type:code id: tags:
``` python
import matplotlib.pyplot as plt
```
%% Cell type:code id: tags:
``` python
time_axis = np.arange(x.shape[0]) / sr  # one time stamp per sample
plt.figure(figsize=(7, 3))
plt.plot(time_axis, x)
plt.title('waveform')
plt.ylabel('amplitude')
plt.xlabel('time (s)')
```
%% Cell type:markdown id: tags:
### Playing Audio
%% Cell type:markdown id: tags:
Using [`IPython.display.Audio`](http://ipython.org/ipython-doc/2/api/generated/IPython.lib.display.html#IPython.lib.display.Audio), you can play an audio file:
%% Cell type:code id: tags:
``` python
import IPython.display as ipd
ipd.Audio(x, rate=sr) # play the NumPy array x at sample rate sr
```
%% Cell type:markdown id: tags:
### Writing Audio
%% Cell type:markdown id: tags:
[`soundfile.write`](https://pysoundfile.readthedocs.io/en/0.8.1/#soundfile.write) saves a NumPy array to a WAV file.
%% Cell type:code id: tags:
``` python
import soundfile as sf
sf.write('example.wav', x, sr)
```
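%% Cell type:markdown id: tags:
As a quick check (a small sketch, assuming `example.wav` was written to the current working directory by the cell above), you can read the file back with PySoundFile and verify the shape and sample rate:
%% Cell type:code id: tags:
``` python
import soundfile as sf
# Read the file back; sf.read returns the audio array and the sample rate
x_check, sr_check = sf.read('example.wav')
print(x_check.shape, sr_check)
```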
%% Cell type:markdown id: tags:
## Mel spectrogram
%% Cell type:markdown id: tags:
In this project, we will work with a time-frequency representation of audio signals called the Mel spectrogram. It is computed as follows:
%% Cell type:markdown id: tags:
#### Framing
The waveform is converted into a sequence of successive overlapping frames.
%% Cell type:code id: tags:
``` python
# Define the parameters of the short-term analysis
window_length_secs = vggish_params.STFT_WINDOW_LENGTH_SECONDS
hop_length_secs = vggish_params.STFT_HOP_LENGTH_SECONDS
window_length_samples = int(round(sr * window_length_secs))
hop_length_samples = int(round(sr * hop_length_secs))
num_samples = x.shape[0]
num_frames = 1 + int(np.floor((num_samples - window_length_samples) /
hop_length_samples))
# Create an array of shape (window_length_samples, num_frames) where each column
# contains a frame of the original audio signal
shape = (num_frames, window_length_samples)
strides = (x.strides[0] * hop_length_samples,) + x.strides
X_frames = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides).T
print(X_frames.shape)
```
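%% Cell type:markdown id: tags:
As a sanity check (a small sketch using the arrays defined above), each column of `X_frames` should equal the corresponding hop-spaced slice of `x`:
%% Cell type:code id: tags:
``` python
# Frame k of X_frames should be the slice x[k*hop : k*hop + window_length]
k = 5  # arbitrary frame index for the check
frame_from_slicing = x[k * hop_length_samples : k * hop_length_samples + window_length_samples]
print(np.allclose(X_frames[:, k], frame_from_slicing))  # expected: True
```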
%% Cell type:markdown id: tags:
#### Windowing
%% Cell type:markdown id: tags:
Each frame is multiplied with a smooth analysis window.
%% Cell type:code id: tags:
``` python
window = .5 - (0.5 * np.cos(2 * np.pi / window_length_samples *
np.arange(window_length_samples))) # "periodic" Hann
X_windowed_frames = X_frames * window[:,np.newaxis]
plt.figure()
plt.plot(window)
print(X_windowed_frames.shape)
plt.title('analysis window')
plt.xlabel('samples')
```
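%% Cell type:markdown id: tags:
For reference, this "periodic" Hann window should coincide with the periodic Hann window returned by SciPy (a small check, assuming `scipy` is available in the environment):
%% Cell type:code id: tags:
``` python
from scipy.signal import get_window
# fftbins=True returns the "periodic" Hann window, i.e. 0.5 - 0.5*cos(2*pi*n/N)
periodic_hann = get_window('hann', window_length_samples, fftbins=True)
print(np.allclose(window, periodic_hann))  # expected: True
```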
%% Cell type:markdown id: tags:
#### Discrete Fourier transform
%% Cell type:markdown id: tags:
The short-term Fourier transform (STFT) is computed by applying the discrete Fourier transform (DFT) on each windowed frame. The magnitude spectrogram is obtained by taking the modulus of the STFT matrix.
%% Cell type:code id: tags:
``` python
import librosa.display
fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
X_stft = np.fft.rfft(X_windowed_frames, int(fft_length), axis=0)
X_spec = np.abs(X_stft)
plt.figure(figsize=(14, 7))
librosa.display.specshow(librosa.amplitude_to_db(X_spec), sr=sr,
hop_length=hop_length_samples, x_axis='time',
y_axis='hz')
# This is basically equivalent to:
# librosa.display.specshow(20*np.log10(X_spec), sr=sr,
# hop_length=hop_length_samples, x_axis='time',
# y_axis='hz')
# plt.clim(-60,25)
plt.colorbar()
plt.title('dB-scaled spectrogram')
plt.xlabel('time (s)')
plt.ylabel('frequency (Hz)')
```
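%% Cell type:markdown id: tags:
To make the "basically equivalent" comment above precise: assuming librosa's default arguments (`ref=1.0`, `amin=1e-5`, `top_db=80`), `amplitude_to_db` amounts to 20*log10 with a small amplitude floor and an 80 dB dynamic-range clip (a small check):
%% Cell type:code id: tags:
``` python
# amplitude_to_db with default arguments: 20*log10 with an amplitude floor (amin)
# followed by clipping to 80 dB below the maximum (top_db)
db_manual = 20 * np.log10(np.maximum(X_spec, 1e-5))
db_manual = np.maximum(db_manual, db_manual.max() - 80.0)
print(np.allclose(db_manual, librosa.amplitude_to_db(X_spec)))  # expected: True
```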
%% Cell type:markdown id: tags:
#### Mel filterbank
%% Cell type:markdown id: tags:
A filterbank matrix is created to map DFT-frequency bins into Mel-frequency bins.
%% Cell type:code id: tags:
``` python
import mel_features
lower_edge_hertz = vggish_params.MEL_MIN_HZ
upper_edge_hertz = vggish_params.MEL_MAX_HZ
num_mel_bins = vggish_params.NUM_MEL_BINS
spec_to_mel_mat = mel_features.spectrogram_to_mel_matrix(num_mel_bins=num_mel_bins,
num_spectrogram_bins=X_spec.shape[0],
audio_sample_rate=sr,
lower_edge_hertz=lower_edge_hertz,
upper_edge_hertz=upper_edge_hertz)
print(spec_to_mel_mat.T.shape)
plt.figure(figsize=(14, 7))
plt.imshow(spec_to_mel_mat.T, origin='lower')
plt.colorbar(orientation='horizontal')
plt.set_cmap('magma')
plt.title('Mel filterbank matrix')
plt.xlabel('DFT-frequency bins')
plt.ylabel('Mel-frequency bins')
```
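%% Cell type:markdown id: tags:
To see how the filterbank behaves across frequency, you can also plot a few individual Mel filters, i.e. columns of `spec_to_mel_mat` (a small sketch; the Mel-bin indices below are arbitrary and assume `num_mel_bins` is 64, as in `vggish_params`):
%% Cell type:code id: tags:
``` python
# Plot a few individual Mel filters (columns of the filterbank matrix)
plt.figure(figsize=(10, 3))
for mel_bin in [0, 15, 30, 45, 63]:
    plt.plot(spec_to_mel_mat[:, mel_bin], label='Mel bin {}'.format(mel_bin))
plt.legend()
plt.title('a few Mel filters')
plt.xlabel('DFT-frequency bins')
plt.ylabel('weight')
```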
%% Cell type:markdown id: tags:
#### Mel spectrogram
%% Cell type:markdown id: tags:
---
### Question
How do you obtain the Mel spectrogram from the filterbank matrix and the spectrogram?
---
%% Cell type:code id: tags:
``` python
X_mel_spec = # TODO
plt.figure(figsize=(14, 7))
librosa.display.specshow(librosa.amplitude_to_db(X_mel_spec), sr=sr,
hop_length=hop_length_samples, x_axis='time')
plt.set_cmap('magma')
plt.colorbar()
plt.title('dB-scaled Mel-spectrogram')
plt.xlabel('time (s)')
plt.yticks(np.arange(0,num_mel_bins,10))
plt.ylabel('Mel-frequency bins')
```
%% Cell type:markdown id: tags:
---
### Questions
1. What is the Mel scale?
2. Explain the effect of the Mel filterbank matrix on the time-frequency representation. What happens to the low and high frequencies?
3. Compare the Mel spectrograms of several audio files in the dataset with different labels. This is just to observe that different audio events have different time-frequency patterns, which allow you (and hopefully the machine learning system you will develop) to discriminate between different sounds.
---
......
%% Cell type:markdown id: tags:
# Convolutional neural network and transfer learning for urban sound tagging
In this notebook, you will build and train a convolutional neural network (CNN) to perform urban sound tagging with [Keras](https://keras.io/). Using transfer learning, your CNN will build upon a model called [VGGish](https://github.com/tensorflow/models/tree/master/research/audioset/vggish). It was trained on [AudioSet](https://github.com/tensorflow/models/tree/master/research/audioset), a dataset of over 2 million human-labeled 10-second YouTube video soundtracks, with labels taken from an ontology of more than 600 audio event classes. This represents more than 5 thousand hours of audio.
The method you will implement here is based on ["Convolutional Neural Networks with Transfer Learning for Urban Sound Tagging"](http://dcase.community/documents/challenge2019/technical_reports/DCASE2019_Kim_107.pdf), which was proposed by Bongjun Kim (Department of Computer Science, Northwestern University, Evanston, Illinois, USA) and obtained the 3rd best score at the [DCASE 2019 Challenge, task 5](http://dcase.community/challenge2019/task-urban-sound-tagging).
**Before working on the rest of this notebook, take some time to read and understand what VGGish, AudioSet, and the above-mentioned method submitted to the DCASE 2019 challenge are.**
%% Cell type:markdown id: tags:
## Environment setup
### Mount Google Drive
%% Cell type:code id: tags:
``` python
from google.colab import drive
drive.mount('/content/drive')
```
%% Cell type:markdown id: tags:
### Get some utilities
%% Cell type:code id: tags:
``` python
import os
if not os.path.exists('mel_features.py'):
    !wget https://gitlab-research.centralesupelec.fr/sleglaive/embedded-ust-students/raw/master/mel_features.py
if not os.path.exists('utils.py'):
    !wget https://gitlab-research.centralesupelec.fr/sleglaive/embedded-ust-students/raw/master/utils.py
if not os.path.exists('vggish_params.py'):
    !wget https://gitlab-research.centralesupelec.fr/sleglaive/embedded-ust-students/raw/master/vggish_params.py
```
%% Cell type:markdown id: tags:
### Define important paths
%% Cell type:code id: tags:
``` python
ust_data_dir = './drive/My Drive/data/ust-data'
dataset_dir = os.path.join(ust_data_dir, 'sonyc-ust')
annotation_file = os.path.join(dataset_dir, 'annotations.csv')
taxonomy_file = os.path.join(dataset_dir, 'dcase-ust-taxonomy.yaml')
log_mel_spec_dir = os.path.join(ust_data_dir, 'log-mel-spectrograms')
output_training_dir = os.path.join(ust_data_dir, 'output_training')
output_prediction_dir = os.path.join(ust_data_dir, 'output_prediction')
```
%% Cell type:markdown id: tags:
### Download features (optional)
If you haven't finished extracting the features yet, you can download them in order to continue working.
%% Cell type:code id: tags:
``` python
# if not(os.path.isdir(log_mel_spec_dir)):
# os.makedirs(log_mel_spec_dir)
# %pushd /content/drive/My\ Drive/data/ust-data/log-mel-spectrograms
# !wget https://gitlab-research.centralesupelec.fr/sleglaive/embedded-ust-students/raw/master/data.npy
# %popd
```
%% Cell type:markdown id: tags:
### Install missing packages
%% Cell type:code id: tags:
``` python
!pip install oyaml
```
%% Cell type:markdown id: tags:
## Define parameters
%% Cell type:code id: tags:
``` python
%tensorflow_version 1.x
import os
import pandas as pd
import oyaml as yaml
import numpy as np
from utils import get_file_targets, get_subset_split
import datetime
import pytz
import json
import keras
from keras.models import Model
from keras.layers import Flatten, Dense, Input, Conv2D, MaxPooling2D, GlobalMaxPooling1D, Reshape
from keras.optimizers import Adam
import resampy
import vggish_params
from IPython.display import clear_output
```
%% Cell type:markdown id: tags:
In the following cell, you have to set several hyperparameters of the learning algorithm.
%% Cell type:code id: tags:
``` python
model_name = 'my_model'
learning_rate = ? # learning rate for gradient descent (actually, Adam optimization method)
batch_size = ? # size of the mini-batches
num_epochs = ? # number of epochs
patience = ? # early stopping patience
tz_Paris = pytz.timezone('Europe/Paris')
datetime_Paris = datetime.datetime.now(tz_Paris)
timestamp = datetime_Paris.strftime("%Y-%m-%d-%Hh%Mm%Ss")
exp_id = model_name + '_' + timestamp
print(exp_id)
```
%% Cell type:code id: tags:
``` python
# save parameters to disk
params = {'annotation_file': annotation_file,
'taxonomy_file': taxonomy_file,
'exp_id': exp_id,
'log_mel_spec_dir': log_mel_spec_dir,
'output_dir': output_training_dir,
'learning_rate': learning_rate,
'batch_size': batch_size,
'num_epochs': num_epochs,
'patience': patience}
results_dir = os.path.join(output_training_dir, exp_id)
os.makedirs(results_dir, exist_ok=True)
kwarg_file = os.path.join(results_dir, "hyper_params.json")
with open(kwarg_file, 'w') as f:
    json.dump(params, f, indent=2)
```
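%% Cell type:markdown id: tags:
The saved JSON file can be reloaded later (for example at prediction time) to recover the hyperparameters of a given experiment. A small usage example:
%% Cell type:code id: tags:
``` python
# Reload the hyperparameters that were just saved to disk
with open(kwarg_file, 'r') as f:
    reloaded_params = json.load(f)
print(reloaded_params['exp_id'])
```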
%% Cell type:markdown id: tags:
## Load annotations and taxonomy
%% Cell type:code id: tags:
``` python
# Create a Pandas DataFrame from the annotation CSV file
annotation_data = pd.read_csv(annotation_file).sort_values('audio_filename')
# List of all audio files
file_list = annotation_data['audio_filename'].unique().tolist()
# Load taxonomy
with open(taxonomy_file, 'r') as f:
    taxonomy = yaml.load(f, Loader=yaml.Loader)
# get list of labels from taxonomy
labels = ["_".join([str(k), v]) for k,v in taxonomy['coarse'].items()]
# list of one-hot encoded labels for all audio files
target_list = get_file_targets(annotation_data, labels)
# list of indices for the training, validation and test subsets
train_file_idxs, val_file_idxs, test_file_idxs = get_subset_split(annotation_data)
# number of classes
n_classes = len(labels)
```
%% Cell type:markdown id: tags:
## Load log-Mel spectrograms
%% Cell type:code id: tags:
``` python
how_saved = 'global' # 'individual' or 'global'
if how_saved == 'global':
    log_mel_spec_list = list(np.load(os.path.join(log_mel_spec_dir, 'data.npy')))
elif how_saved == 'individual':
    # Create a list of log-Mel spectrograms of size 998 frames × 64 Mel-frequency bins
    log_mel_spec_list = []
    for idx, filename in enumerate(file_list):
        clear_output(wait=True)
        log_mel_path = os.path.join(log_mel_spec_dir, os.path.splitext(filename)[0] + '.npy')
        log_mel_spec = np.load(log_mel_path)
        log_mel_spec_list.append(log_mel_spec)
        print('({}/{})'.format(idx+1, len(file_list)))
```
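%% Cell type:markdown id: tags:
Before building the (input, output) pairs, it is worth checking what was loaded (a quick check; each log-Mel spectrogram is expected to have 998 frames and 64 Mel bins, as stated in the comment above):
%% Cell type:code id: tags:
``` python
# Number of clips and shape of one log-Mel spectrogram (expected: (998, 64))
print(len(log_mel_spec_list))
print(log_mel_spec_list[0].shape)
```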
%% Cell type:code id: tags:
``` python
# Create training set (input, output) pairs
train_x = []
train_y = []
for idx in train_file_idxs:
    train_x.append(log_mel_spec_list[idx])
    train_y.append(target_list[idx])
perm_train_idxs = np.random.permutation(len(train_x))
train_x = np.array(train_x)[perm_train_idxs]
train_y = np.array(train_y)[perm_train_idxs]
# Create validation set (input, output) pairs
val_x = []
val_y = []
for idx in val_file_idxs:
    val_x.append(log_mel_spec_list[idx])
    val_y.append(target_list[idx])
perm_val_idxs = np.random.permutation(len(val_x))
val_x = np.array(val_x)[perm_val_idxs]
val_y = np.array(val_y)[perm_val_idxs]
```
%% Cell type:markdown id: tags:
## VGGish
%% Cell type:markdown id: tags:
[VGGish](https://github.com/tensorflow/models/tree/master/research/audioset/vggish) is a variant of the [VGG](https://arxiv.org/abs/1409.1556) model, in
particular Configuration A with 11 weight layers. Specifically, here are the
changes that were made:
* The input size was changed to 96x64 for log mel spectrogram audio inputs.
* The last group of convolutional and maxpool layers was dropped, so we now have
only four groups of convolution/maxpool layers instead of five.
* Instead of a 1000-wide fully connected layer at the end, a 128-wide
fully connected layer was used. This acts as a compact embedding layer.
You will have access to a pre-trained VGGish Keras model. It was trained on [AudioSet](https://github.com/tensorflow/models/tree/master/research/audioset), a dataset of over 2 million human-labeled 10-second YouTube video soundtracks, with labels taken from an ontology of more than 600 audio event classes. This represents more than 5 thousand hours of audio.
In the following cell, you have to define the VGGish model in Keras, using the [Functional API](https://keras.io/models/model/). Look for the information you need in the [VGG](https://arxiv.org/abs/1409.1556) paper, and take the above-mentioned modifications into account.
Hint: Look at the imports to know which Keras layers you need.
%% Cell type:code id: tags:
``` python
input_shape = (96, 64, 1) # see vggish_params.py
img_input = Input(shape=input_shape)
x = # TODO: Define the VGGish model in Keras, looking at the VGG paper
# and taking into account the above-mentioned changes.
vggish_model = Model(img_input, x, name='vggish')
```
%% Cell type:markdown id: tags:
Your goal is to use the pre-trained VGGish model for transfer learning, i.e. adapting the model to your specific task and dataset.
In the following cell, you will load the pre-trained weights into your model. It will not work correctly if you did not define the proper architecture. You can look at the architecture of your model with `vggish_model.summary()`.
%% Cell type:code id: tags:
``` python
vggish_weights_file = 'vggish_weights.ckpt'
if not os.path.exists(vggish_weights_file):
    !wget https://gitlab-research.centralesupelec.fr/sleglaive/embedded-ust-students/raw/master/vggish_weights.ckpt
vggish_model.load_weights(vggish_weights_file)
```
%% Cell type:markdown id: tags:
## Model adaptation
You cannot directly use the previous VGGish model. You have to define a new model that is inspired by VGGish but also satisfies the following requirements:
- Input layer: It should match the dimension of your audio clips.
- Convolutional layers (and intermediate max-pooling layers): You will use the same ones as in VGGish, up to ```conv4_2```. You will have to initialize the convolutional layers of your new model with the parameters from the pre-trained VGGish model. During training, you will freeze these parameters.
- Temporal pooling: We want to make clip-level predictions for arbitrary clip durations, so frame-level feature maps should be pooled (e.g. with a max pooling) along the time axis.
- Fully-connected layers: You will then have a few fully-connected layers that are randomly initialized and whose parameters are learned on the training set. Choose the output layer to perform 8-class prediction.
---
### Question
Why do we keep the convolutional layers only, and drop the fully-connected ones?
---
In the following cell, you first have to define this new model; then you will transfer some weights from VGGish. You can use ```my_init``` to initialize the parameters of the dense layers (see the documentation).
%% Cell type:code id: tags:
``` python
my_init = keras.initializers.VarianceScaling(scale=1/3,
mode='fan_in',
distribution='uniform',
seed=None)
complete_model = # TODO
complete_model.summary()
```
%% Cell type:markdown id: tags:
You now have to perform transfer learning. The first 6 convolutional layers are initialized with parameters from convolutional layers
of the pre-trained VGGish model. During training, the first four convolutional layers are fixed (not updated) and the remaining ones
are fine-tuned on the training set. Use the ```get_weights()```, ```set_weights()``` methods and the ```.trainable``` property of Keras layers.
%% Cell type:code id: tags:
``` python
# List of layers to copy (use the names of the layers)
layers_to_copy = []
for layer in layers_to_copy:
    # TODO: copy the weights, see 'get_layer', 'set_weights' and 'get_weights' in Keras
    pass

# List of layers to freeze
layers_to_freeze = []
for layer in layers_to_freeze:
    # TODO: freeze the weights by setting the 'trainable' property to False
    pass
```
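%% Cell type:markdown id: tags:
As a reminder of the Keras calls involved, here is a generic sketch on a hypothetical layer name (for illustration only, not the list of layers you need for this model):
%% Cell type:code id: tags:
``` python
# Hypothetical example: copy the weights of one layer from the pre-trained VGGish model
# into the new model, then freeze that layer so it is not updated during training.
layer_name = 'conv1'  # hypothetical layer name, for illustration only
pretrained_weights = vggish_model.get_layer(layer_name).get_weights()
complete_model.get_layer(layer_name).set_weights(pretrained_weights)
complete_model.get_layer(layer_name).trainable = False
```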
%% Cell type:markdown id: tags:
## Training
In Edit > Notebook settings or Runtime > Change runtime type, select GPU as Hardware accelerator.
---
### Question
What loss should you use?
---
%% Cell type:code id: tags:
``` python
loss = # TODO
# Set up callbacks for early stopping and monitoring the loss
cb = []
# checkpoint
model_weight_file = os.path.join(results_dir, 'best_model_weights.h5')
cb.append(keras.callbacks.ModelCheckpoint(model_weight_file,
save_weights_only=True,
save_best_only=True,
monitor='val_loss'))
# early stopping
cb.append(keras.callbacks.EarlyStopping(monitor='val_loss',
patience=patience))
# monitor losses
history_csv_file = os.path.join(results_dir, 'history.csv')
cb.append(keras.callbacks.CSVLogger(history_csv_file, append=True,
separator=','))
# Compile model using Adam optimizer
complete_model.compile(Adam(lr=learning_rate), loss=loss)
history = complete_model.fit(
x=train_x[:,:,:,np.newaxis], y=train_y, batch_size=batch_size, epochs=num_epochs,
validation_data=(val_x[:,:,:,np.newaxis], val_y), callbacks=cb, verbose=2)
# save architecture
with open(os.path.join(results_dir, 'model_architecture.json'), 'w') as json_file:
    json_file.write(complete_model.to_json())
```
%% Cell type:code id: tags:
``` python
import matplotlib.pyplot as plt
train_loss, = plt.plot(history.history['loss'], label='training loss')
val_loss, = plt.plot(history.history['val_loss'], label='validation loss')
plt.legend(handles=[train_loss, val_loss])
plt.xlabel('epochs')
plt.ylabel('loss')
```
......