SLML Part 7 - SketchRNN Experiments: Data Augmentation

Experiments training SketchRNN on my dataset of single-line drawings.

Andrew Look


January 4, 2024


This post is part 7 of “SLML” - Single-Line Machine Learning.

To read the previous post, check out part 6.

Preprocessing for Data Augmentation

I was curious why my earlier experiment with stroke augmentation didn’t show benefits (and in some cases made the models perform much worse on validation metrics). In Figure 1, it’s clear that the pink line (the run trained with stroke augmentation) has a faster-growing validation loss after the point of overfitting.

Figure 1: Training and validation loss metrics for layernorm-only models (burgundy and gray) alongside models trained with recurrent dropout (green) and with stroke augmentation (pink).

Those experiments applied stroke augmentation as a dataset hyperparameter: at runtime, when training started, the loader would modify the drawings before assembling them into batches for the trainer.

I decided to create a dataset where I ran the augmentation ahead of time and saved the result, so I could inspect the output more closely. For each entry in the dataset, I included the original drawing as well as a separate entry for the drawing in reverse, with a “flipped” point sequence. This doubled the size of the dataset.
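To make the “flipped” entries concrete, here’s a minimal sketch of reversing a drawing, assuming the stroke-3 format SketchRNN uses (rows of `(dx, dy, pen_lift)`); the function name `reverse_drawing` is mine, not from the actual pipeline:

```python
import numpy as np

def reverse_drawing(strokes):
    """Reverse a stroke-3 drawing: rows of (dx, dy, pen_lift).

    Reconstructs absolute coordinates, reverses the point order,
    and recomputes offsets so the drawing is traced back-to-front.
    """
    # Absolute positions of each point (cumulative sum of offsets).
    abs_pts = np.cumsum(strokes[:, :2], axis=0)
    # Reverse the points and re-anchor the new first point at the origin.
    rev_pts = abs_pts[::-1] - abs_pts[-1]
    # Back to offsets: the inverse of the cumulative sum.
    rev_deltas = np.diff(rev_pts, axis=0, prepend=np.zeros((1, 2)))
    # A pen-lift flag marks the *end* of a stroke; after reversing,
    # the flag that followed point i now precedes it, so shift by one.
    pen = strokes[:, 2]
    rev_pen = np.concatenate([pen[::-1][1:], pen[-1:]])
    return np.column_stack([rev_deltas, rev_pen])
```

Reversing preserves the number of strokes and the shape of the drawing; only the direction of travel changes.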

Another idea I’d like to try is applying local distortions or dilations, since that would change the directions certain strokes take without losing the overall subject of the drawing. Radial basis functions or warp grids seem like promising approaches to try.
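This idea isn’t implemented in the post, but a Gaussian RBF distortion might look something like the following sketch (again assuming stroke-3 input; all names and parameters here are hypothetical):

```python
import numpy as np

def rbf_warp(strokes, num_centers=3, strength=0.05, sigma=0.2, rng=None):
    """Smooth local distortion of a stroke-3 drawing via Gaussian RBFs.

    Random control points each push nearby points by a random offset,
    weighted by a Gaussian falloff, so stroke directions change locally
    while the overall subject of the drawing survives.
    """
    rng = np.random.default_rng(rng)
    pts = np.cumsum(strokes[:, :2], axis=0)
    # Normalize so strength/sigma are relative to the drawing's size.
    span = float(np.ptp(pts, axis=0).max())
    if span == 0.0:
        span = 1.0
    # Pick control points on the drawing and a random push for each.
    centers = pts[rng.integers(len(pts), size=num_centers)]
    offsets = rng.normal(scale=strength * span, size=(num_centers, 2))
    for c, off in zip(centers, offsets):
        w = np.exp(-np.sum((pts - c) ** 2, axis=1) / (2 * (sigma * span) ** 2))
        pts = pts + w[:, None] * off
    # Convert warped absolute points back to offsets.
    deltas = np.diff(pts, axis=0, prepend=np.zeros((1, 2)))
    return np.column_stack([deltas, strokes[:, 2]])
```

Small `strength` values nudge stroke directions without moving points far; larger `sigma` spreads each push over more of the drawing.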

Then I took each drawing and randomly:

  • Applied stroke augmentation to drop points, with a probability up to 0.5.
  • Rotated between -15 and 15 degrees.
  • Scaled between 100% and 120% of original size.
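A minimal sketch of those three steps combined, assuming stroke-3 input — `augment_drawing` and its parameter names are mine, not the actual pipeline’s:

```python
import numpy as np

def augment_drawing(strokes, max_drop_prob=0.5, max_rotate_deg=15.0,
                    max_scale=1.2, rng=None):
    """Randomly drop points, rotate, and scale a stroke-3 drawing."""
    rng = np.random.default_rng(rng)
    drop_prob = rng.uniform(0, max_drop_prob)
    out = []
    carry = np.zeros(2)
    for dx, dy, pen in strokes:
        offset = np.array([dx, dy]) + carry
        # Only drop interior points: always keep stroke-ending points,
        # and fold a dropped point's offset into the next point.
        if pen == 0 and rng.random() < drop_prob:
            carry = offset
            continue
        carry = np.zeros(2)
        out.append([offset[0], offset[1], pen])
    out = np.array(out)
    # Offsets rotate the same way absolute points do, so rotate the
    # offsets directly, then scale uniformly.
    theta = np.deg2rad(rng.uniform(-max_rotate_deg, max_rotate_deg))
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
    scale = rng.uniform(1.0, max_scale)
    out[:, :2] = scale * (out[:, :2] @ rot.T)
    return out
```

Folding a dropped point’s offset into its neighbor keeps each stroke’s endpoints in place, so the drawing loses detail but not shape.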

Some examples of the augmentations are visible in Figure 2.

Original drawing

After “Stroke Augmentation” drops points from lines at random

Randomly rotated and scaled
Figure 2: Examples of data augmentation.

Training With Augmented Dataset

Comparing the validation loss metrics in Figure 3 from the models trained on the augmented dataset (purple) with my previous round of best-performing models, the augmented dataset takes longer to converge, but its validation loss keeps sloping downward. This is encouraging to me, since it seems like the more diverse drawings in the augmented dataset are helping the model generalize better than previous models trained on non-augmented datasets.

Figure 3: Training and validation loss metrics from augmented dataset 20240221-dataaug10x (purple).

As a small side experiment, I wanted to confirm my finding in part 4 that layernorm caused a meaningful improvement. Looking at the loss metrics in Figure 4, it appears that disabling layernorm (light green) causes a significant drop in performance. Disabling recurrent dropout doesn’t have a significant effect, as far as I can tell.

Figure 4: Training and validation loss metrics from augmented dataset 20240221-dataaug10x (purple) compared to variants without layernorm (light green) and without recurrent dropout (magenta / dark red).