SLML Part 4 - SketchRNN Experiments: Minimum Stroke Length and RNN Regularization
This post is part 4 of “SLML” - Single-Line Machine Learning.
To read the previous post, check out part 3. If you want to keep reading, check out part 5.
Now that I had a repeatable process to convert JPEGs into stroke-3, I decided to start training models with my first dataset (which I called look_i16).
Filtering out Short Strokes
The training data in my first dataset, look_i16, had lots of tiny strokes mixed in with longer ones, and no inherent order to them.
Unsurprisingly, my first models produced odd smatterings composed of many small strokes.
As a quick experiment, I tried filtering really short strokes out of the dataset: I iterated through each of the 1300 drawings in the training set and dropped any stroke with fewer than 10 points. I called this dataset look_i16__minn10.
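For reference, here's a minimal sketch of that filtering step. Stroke-3 stores a drawing as rows of (dx, dy, pen_lifted), so a "stroke" is a run of points ending where the pen-lift bit is 1. The function name and structure here are illustrative, not my actual preprocessing code:

```python
import numpy as np

def filter_short_strokes(drawing, min_points=10):
    # `drawing` is a stroke-3 array of shape (N, 3): rows of [dx, dy, pen_lifted],
    # where a stroke ends on a row whose pen_lifted bit is 1.
    strokes, current = [], []
    for point in drawing:
        current.append(point)
        if point[2] == 1:  # pen lifted: this point closes the current stroke
            strokes.append(current)
            current = []
    if current:  # trailing stroke without an explicit pen-up
        strokes.append(current)

    # Keep only strokes with enough points, then flatten back to stroke-3.
    kept = [s for s in strokes if len(s) >= min_points]
    return np.array([p for s in kept for p in s])
```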
Note that the average number of strokes per drawing in look_i16 is around 40, while it's closer to 10 in look_i16__minn10, so there were evidently many very short strokes in the training data. I also simplified the drawings more aggressively for look_i16__minn10 by increasing RDP's epsilon parameter to 1.0 when preprocessing the data, which further reduced the number of points per drawing.
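As a rough illustration of the simplification step, here's RDP applied to one stroke of absolute (x, y) points using the rdp package; the actual preprocessing code may differ, and the coordinates below are made up:

```python
import numpy as np
from rdp import rdp  # Ramer-Douglas-Peucker simplification (pip install rdp)

# One stroke as absolute (x, y) coordinates.
stroke_xy = np.array([[0.0, 0.0], [0.4, 0.1], [1.0, 0.05], [1.6, 0.3], [2.0, 1.0]])

# A larger epsilon tolerates more deviation from the original path,
# so more points get discarded. I bumped epsilon to 1.0 for look_i16__minn10.
simplified = rdp(stroke_xy, epsilon=1.0)
```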
When training models, I'm measuring the reconstruction loss. I feed in a drawing, the model encodes it into an embedding vector, and then decodes the embedding vector back into a drawing. I can compute the loss at each step of the reconstructed drawing compared to the original. Periodically, after processing several batches of training data, I compute the reconstruction metrics on a validation set - a portion of the dataset that isn't used to update the model's weights during training.
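In pseudocode, that periodic validation pass looks roughly like this; encode and decoder_loss are placeholder names for illustration, not the actual Magenta API:

```python
def validation_reconstruction_loss(model, validation_drawings):
    # Encode each validation drawing to a latent vector, decode it back,
    # and average the per-step reconstruction loss against the original.
    losses = []
    for drawing in validation_drawings:
        z = model.encode(drawing)                      # drawing -> embedding vector
        losses.append(model.decoder_loss(z, drawing))  # decode and score
    return sum(losses) / len(losses)
```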
By comparing the reconstruction loss on the training set vs. the validation set over time, I can identify when the model starts "overfitting". Intuitively, if the model keeps improving on the training data while getting worse on the validation data, it is effectively memorizing the training set rather than learning patterns that generalize to drawings it wasn't trained on.
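A simple heuristic for spotting that point, assuming a list of validation losses recorded at each evaluation (an illustrative check, not my actual training code):

```python
def looks_like_overfitting(val_losses, patience=5):
    # If the best validation loss happened more than `patience` evaluations
    # ago, the model is likely memorizing the training set.
    if len(val_losses) <= patience:
        return False
    best_index = val_losses.index(min(val_losses))
    return best_index < len(val_losses) - patience
```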
The model trained on look_i16__minn10 performed slightly better than the model trained on look_i16 in terms of the error when reconstructing a drawing. It's visible in Figure 4 that the loss values were lower, and the validation loss didn't start to increase until slightly later.
The results produced after training on look_i16__minn10 were much less chaotic. While they didn't resemble coherent drawings, this was the first time I spotted some elements of my drawing style (head shape, eye style, lips, chin).
Layer Normalization
The Magenta team had recommended using Layer Normalization and Recurrent Dropout without Memory Loss.
I noticed that when I let the trainer run overnight, I’d get wild spikes in the training loss. I decided to start with Layer Normalization.
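For context, layer normalization normalizes the RNN's activations across the feature dimension at every time step, which keeps the scale of the recurrent state stable. A minimal numpy sketch of the operation itself (not the actual LSTM cell used in training):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each row (one time step of hidden features) to zero mean and
    # unit variance, then apply a learned scale (gamma) and shift (beta).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta
```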
Adding layer normalization made a dramatic difference, visible in Figure 6. The yellow line (without layer norm) has a massive increase in validation loss and many big spikes, while the burgundy and gray lines (with layer norm) have much lower validation loss and no comparable spikes.
Recurrent Dropout
The results from the layernorm models in Figure 7 showed some hints of my drawing style, while still falling short of coherent drawings.
Next, I kept layernorm enabled and added recurrent dropout. I ran separate runs with and without stroke augmentation.
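By stroke augmentation I mean randomly perturbing the training drawings; a minimal sketch of one common SketchRNN-style variant (random scaling of the (dx, dy) offsets), which may not match my runs exactly:

```python
import numpy as np

def augment_strokes(drawing, scale_range=0.15):
    # Randomly rescale the x and y offsets of a stroke-3 drawing by
    # independent factors drawn from [1 - scale_range, 1 + scale_range].
    augmented = np.array(drawing, dtype=float)
    augmented[:, 0] *= np.random.uniform(1 - scale_range, 1 + scale_range)
    augmented[:, 1] *= np.random.uniform(1 - scale_range, 1 + scale_range)
    return augmented
```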
Compared to the layernorm-only models (burgundy and gray), the recurrent-dropout model without augmentation (green) achieved a lower validation loss relatively quickly, before starting to overfit.
The model trained with recurrent dropout and stroke augmentation (pink) clearly performed worse than the one with recurrent dropout alone, reaching a higher validation loss.
The resulting generations from the model with layernorm and recurrent dropout in Figure 9 weren't obviously better or worse than those from the layernorm-only model in Figure 7.
If you want to keep reading, check out part 5 of my SLML series.