SLML Part 7 - SketchRNN Experiments: Data Augmentation
This post is part 7 of “SLML” - Single-Line Machine Learning.
To read the previous post, check out part 6.
Preprocessing for Data Augmentation
I was curious why my earlier experiment with stroke augmentation didn't show benefits (and in some cases made the models perform noticeably worse on validation metrics). In Figure 1, the pink line (without stroke augmentation) shows a faster-growing validation loss after the point of overfitting.
Those experiments applied stroke augmentation as a hyperparameter on the dataset - so at runtime, once training started, the dataset would modify the drawings on the fly before assembling them into batches for the trainer.
I decided to create a dataset where I ran the augmentation ahead of time and saved the results so I could inspect them more closely. For each entry in the dataset, I included the original drawing as well as a separate entry for the drawing in reverse with a "flipped" sequence. This doubled the size of the dataset.
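Sketched out, reversing a stroke-3 drawing means retracing the offsets backwards and shifting the pen-lift flags so they still mark the ends of strokes. The snippet below is a minimal sketch assuming the standard (dx, dy, pen_lifted) representation; the function names are illustrative, not my actual preprocessing code.

```python
import numpy as np

def reverse_drawing(strokes: np.ndarray) -> np.ndarray:
    """Retrace a stroke-3 drawing in the opposite direction.

    Assumes rows of (dx, dy, pen_lifted), where pen_lifted == 1 marks the
    last point of a stroke. Illustrative sketch only.
    """
    rev = np.zeros_like(strokes)

    # Walking the path backwards: take the offsets in reverse order,
    # negate them, and shift them down one row (row 0 keeps a zero offset).
    rev[1:, :2] = -strokes[:0:-1, :2]

    # A reversed stroke ends where the original stroke began, i.e. one point
    # after each original pen lift when read in reverse order.
    rev[:, 2] = np.roll(strokes[::-1, 2], -1)
    rev[-1, 2] = 1  # the final point always ends the drawing
    return rev

# Each dataset entry keeps the original drawing plus its reversed copy,
# which doubles the number of drawings.
def with_reversals(drawings: list[np.ndarray]) -> list[np.ndarray]:
    return [aug for d in drawings for aug in (d, reverse_drawing(d))]
```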
Then I took each drawing and randomly (a rough sketch of these steps follows the list):
- Applied stroke augmentation to drop points, with a probability up to 0.5.
- Rotated it between -15 and 15 degrees.
- Scaled it between 100% and 120% of its original size.
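Put together, the per-drawing augmentation looks roughly like the sketch below. It assumes the drop probability is sampled per drawing (anywhere up to 0.5), that point dropping follows the usual SketchRNN recipe of merging a dropped point's offset into the next kept point, and that stroke start/end points are never dropped; the function name and exact details are illustrative.

```python
import numpy as np

def augment_drawing(strokes: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Apply the three random augmentations to one stroke-3 drawing (sketch)."""
    out = strokes.astype(np.float64)

    # 1. Stroke augmentation: drop interior points with probability up to 0.5.
    drop_prob = rng.uniform(0.0, 0.5)
    kept = []
    carry = np.zeros(2)
    for i, (dx, dy, pen) in enumerate(out):
        carry += (dx, dy)
        prev_pen = out[i - 1, 2] if i > 0 else 1.0
        droppable = pen == 0 and prev_pen == 0   # never drop stroke start/end points
        if droppable and rng.random() < drop_prob:
            continue                             # merge offset into the next kept point
        kept.append([carry[0], carry[1], pen])
        carry[:] = 0.0
    out = np.array(kept)

    # 2. Random rotation between -15 and +15 degrees (offsets rotate like vectors).
    theta = np.deg2rad(rng.uniform(-15.0, 15.0))
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
    out[:, :2] = out[:, :2] @ rot.T

    # 3. Random uniform scale between 100% and 120% of the original size.
    out[:, :2] *= rng.uniform(1.0, 1.2)
    return out
```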
Some examples of the augmentations are visible in Figure 2.
Training With Augmented Dataset
Comparing the validation loss metrics in Figure 3 from the models trained on the augmented dataset (purple) with my previous round of best-performing models, the augmented-dataset models take longer to converge but their validation loss keeps sloping downwards. This is encouraging to me, since it seems like the more diverse drawings in the augmented dataset are helping the model generalize better than previous models trained on non-augmented datasets.
Figure 3: Validation loss for the augmented-dataset run 20240221-dataaug10x (purple).
As a small side experiment, I wanted to confirm my finding in part 4 that layer norm caused a meaningful improvement. Looking at the loss metrics in Figure 4, it appears that disabling layer norm (light green) causes a significant drop in performance. Disabling recurrent dropout doesn't have a significant effect, as far as I can tell.
Figure 4: Validation loss for 20240221-dataaug10x (purple) compared to variants without layer norm (light green) and without recurrent dropout (magenta / dark red).