In this work log, I explore data-parallel distributed training in Keras. I try different combinations of GPU count (1, 2, 4, or 8 GPUs) and batch size, varying both the total (original) batch size and the per-GPU (effective) batch size. I also increase the dataset size and compare evaluation methods. I include my notes and ideas for next steps from the different experiments to show a realistic research process.

Initial observations

Data and model

The core model is a basic 7-layer convolutional network that predicts one of 10 animal classes (bird, mammal, reptile, etc.). It is trained on class-balanced subsets of iNaturalist 2017, typically 5000 train / 800 test images for fast iteration. You can read more about this task with more powerful models, or in the context of curriculum learning.


The code to train a data-parallel model is in this example gist.

Basic Multi-GPU in Keras

Smaller batches for accuracy, larger for speed on 2 GPUs

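Keras has a built-in utility for single-machine data parallelism: multi_gpu_model in keras.utils, which replicates the model on each GPU and splits every batch across the replicas. It is trivial to enable; here are examples on a local machine and on GCP with 2 GPUs.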
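To make the "subbatch" terminology used below concrete, here is a pure-Python sketch of how one batch is divided among GPU replicas. As I understand it, Keras 2.x's multi_gpu_model (enabled with something like `parallel_model = multi_gpu_model(model, gpus=2)`) gives each replica `batch_size // gpus` samples and lets the last replica also take any remainder; `split_batch` is a hypothetical helper that only mimics that arithmetic and does not call Keras.

```python
from typing import List


def split_batch(batch_size: int, gpus: int) -> List[int]:
    """Return how many samples each GPU replica sees for one batch.

    Hypothetical helper: assumes an even split, with any remainder
    assigned to the last replica.
    """
    if gpus < 1:
        raise ValueError("need at least one GPU")
    step = batch_size // gpus
    sizes = [step] * (gpus - 1)
    # The last replica takes whatever is left over.
    sizes.append(batch_size - step * (gpus - 1))
    return sizes


if __name__ == "__main__":
    print(split_batch(64, 2))   # [32, 32]
    print(split_batch(64, 8))   # [8, 8, 8, 8, 8, 8, 8, 8]
    print(split_batch(100, 3))  # [33, 33, 34]
```

The per-GPU (effective) batch size is therefore the original batch size divided by the GPU count, which is why the same original batch size of 64 means a subbatch of 16 on 4 GPUs but only 8 on 8 GPUs.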


Next steps

Scaling to 4 GPUs

Speedup: 2.5X on 4 GPUs, 1.6X on 2 GPUs

2.5X faster training on 4 GPUs vs 1 GPU


Train a 7-layer convnet on the main iNaturalist dataset (5000 train / 800 val) as a proof of concept for the Keras multi_gpu_model function.


Next steps

Batch size matters more than GPU count

Use smaller batches when accuracy matters

Notes/next steps

Training on 2 GPUs with 10X the data

If time-constrained, train with larger batches on less data

Compare performance when training with 5K (blue lines) vs 50K (red lines) images on 2 GPUs. For a fixed 50 epochs, training time grows linearly with dataset size, and the extra data doesn't significantly improve this particular model. The 50K run plateaus at about the point in time when the 5K run finishes training (by which point the 5K run has only barely started to plateau). The 50K run does reach a slightly higher max validation accuracy (49% vs 45% for the 5K case), though this decays with further training. The effects of parallelization and of using 10X more data are much more obvious when plotting against wall-clock training time than against epochs seen.
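The linear growth in training time follows from how an epoch is defined: one full pass over the data, so 10X more examples means roughly 10X more batches per epoch and, at a roughly constant time per batch, roughly 10X the wall-clock time. A small sketch of that arithmetic; the batch size of 64 is an assumption for illustration, since these runs used several batch sizes.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    # One gradient update per batch; a final partial batch still counts as a step.
    return math.ceil(num_examples / batch_size)


EPOCHS = 50
BATCH_SIZE = 64  # assumed for illustration

for n in (5_000, 50_000):
    steps = steps_per_epoch(n, BATCH_SIZE)
    print(f"{n:>6} examples: {steps} steps/epoch, {steps * EPOCHS} steps in {EPOCHS} epochs")
```

With these assumptions the 5K set needs 79 steps per epoch and the 50K set 782, so 50 epochs cost 3950 vs 39100 steps, almost exactly the 10X ratio in dataset size.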

Note that for a given amount of training time (up to 3.5 hours), training on 5K examples with larger batch sizes yields higher validation accuracy than training on 50K examples. The 50K case eventually surpasses the 5K cases, but the difference in max validation accuracy reached is only about 4-5 percentage points.

Next steps

Train 2.5X faster on 4 GPUs, 3X on 8 GPUs?

Linear-ish speedup with distributed training

Distributing over 4 GPUs, even for such a small network and dataset, gives a 2.5X speedup relative to 1 GPU; 2 GPUs give a 1.6X speedup. Overall, the improvement is not linear in GPU count but is still substantial. Increasing the compute to 8 GPUs doesn't improve the runtime with the batch sizes tried so far, which breaks the overall trend of decreasing training time as GPUs are added.
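One way to quantify "not linear but still substantial" is parallel efficiency: the measured speedup divided by the ideal (linear) speedup, i.e. the GPU count. The sketch below plugs in the speedups reported above; the `speedup` and `parallel_efficiency` helpers are hypothetical names for the standard definitions.

```python
def speedup(time_1gpu: float, time_ngpu: float) -> float:
    """How many times faster the N-GPU run is than the 1-GPU run."""
    return time_1gpu / time_ngpu


def parallel_efficiency(measured_speedup: float, gpus: int) -> float:
    """Fraction of ideal linear speedup achieved (1.0 means perfectly linear)."""
    return measured_speedup / gpus


# Speedups relative to 1 GPU, as reported in this log.
measured = {2: 1.6, 4: 2.5}

for gpus, s in measured.items():
    print(f"{gpus} GPUs: {s:.1f}X speedup, {parallel_efficiency(s, gpus):.1%} efficiency")
```

This gives 80.0% efficiency on 2 GPUs and 62.5% on 4 GPUs: each added GPU still helps, but by less than the one before.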

Note that the 8-GPU runs are not directly comparable, as they use double the training data and more than double the validation data.

Next steps

Optimal batch size & GPU count? It depends.

Find an optimal batch size for the problem

On average, an original batch size of 64 still does best, compared to scaling the batch size up for 8 GPUs. For example, it may be better to run on 4 GPUs with original batch size 64 (subbatch size 16) than on 8 GPUs with batch size 64 (subbatch size 8), or than on 4 GPUs with batch size 256 (subbatch size 64, which is what Keras would recommend).

8 GPUs with the maximum batch size of 512 (subbatch size 64) is still best overall, assuming one has access to this extra compute and is willing to explore the best configuration for a particular problem.

Tentative: Not much impact from fixed subbatch size & scaled batch size

Subbatch size matters at 8 GPUs, not before?

Keras recommends increasing the original batch size proportionally with GPU count. In this scenario, the subbatch size is fixed at 64 and the original model's batch size is scaled up accordingly: 64, 128, 256, and 512 for 1, 2, 4, and 8 GPUs.
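Under this proportional-scaling rule the work per GPU stays constant while the number of weight updates per epoch shrinks as GPUs are added. A sketch of that arithmetic, assuming the typical 5000-example training set; `scaled_schedule` is a hypothetical helper, not a Keras function.

```python
import math
from typing import Iterable, List, Tuple


def scaled_schedule(
    subbatch: int, gpu_counts: Iterable[int], num_examples: int
) -> List[Tuple[int, int, int]]:
    """For each GPU count, return (gpus, original batch size, updates per epoch).

    The original batch size is the fixed per-GPU subbatch times the GPU count.
    """
    rows = []
    for gpus in gpu_counts:
        batch = subbatch * gpus
        updates = math.ceil(num_examples / batch)
        rows.append((gpus, batch, updates))
    return rows


for gpus, batch, updates in scaled_schedule(64, (1, 2, 4, 8), 5_000):
    print(f"{gpus} GPU(s): batch {batch}, {updates} weight updates per epoch")
```

With 5000 examples this works out to 79, 40, 20, and 10 updates per epoch, so the 8-GPU run takes roughly 8X fewer optimizer steps per epoch than the 1-GPU run. That is worth keeping in mind when comparing these runs epoch for epoch.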

On 8 GPUs, training is ~3X faster. There is no noticeable speedup between 1, 2, and 4 GPUs. This is surprising; perhaps the model needs to be more complex or the data load heavier for the extra GPUs to pay off.

Tentative: Vary subbatch size on 8 GPUs

Maximize batch size to be safe?

Larger batch sizes appear more reliable. Note the long, relatively flat stretches of the maroon and light blue lines, as if computation temporarily slowed down (cloud issues, perhaps?). I need to run more experiments before drawing solid conclusions.