Efficiently scaling modern deep learning models requires a precise understanding of how hardware and algorithms interact, and this book explains that interaction step by step. Even an intuitive grasp of the principles pays off tremendously in both practice and research; the focus is on parallelism, performance bottlenecks, cost estimation, and fit with real hardware. The book also offers hands-on tutorials, JAX usage, TPU and Transformer architecture, and more: a must-read if you want to work effectively with large models!
1. Getting Started: Why Should You Care About Model Scaling?
Deep learning is a field with so much uncertainty that it's often called "black magic," but at least when it comes to efficiently scaling large models, relatively clear principles apply. This guide covers core principles applicable from a single accelerator to thousands of chips, aiming to directly help your work:
- You can roughly gauge how close each part of your model is to the theoretical peak performance.
- It helps you wisely choose a parallelism strategy suited to your situation and hardware.
- It enables you to estimate the time and cost needed to train and operate large Transformer models.
- It makes effective algorithm design possible, tailored to specific hardware characteristics.
- You can understand what performance limits hardware design actually hits.
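The time-and-cost estimates mentioned above lend themselves to back-of-envelope math. Here is a minimal sketch using the widely cited ~6·N·D FLOPs approximation for dense Transformer training; all numbers (chip count, peak FLOP/s, utilization) are hypothetical placeholders, not figures from the book:

```python
def train_time_days(n_params, n_tokens, n_chips, peak_flops_per_chip, mfu):
    """Rough training-time estimate via the common ~6*N*D FLOPs rule
    for dense Transformers (N = parameters, D = training tokens)."""
    total_flops = 6 * n_params * n_tokens
    sustained_flops = n_chips * peak_flops_per_chip * mfu  # achieved FLOP/s
    return total_flops / sustained_flops / 86_400  # seconds -> days

# Hypothetical run: 70B params, 15T tokens, 8192 chips at a peak of
# 4.6e14 FLOP/s each, with 40% model FLOPs utilization (MFU).
days = train_time_days(70e9, 15e12, 8192, 4.6e14, 0.40)
```

Multiplying the result by a per-chip-hour price gives a first-order cost estimate; the book develops much more careful versions of this arithmetic.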
"The goal of model scaling is to increase throughput linearly as you add more chips for training or inference."
Recently, models have been hitting the performance limits of hardware more frequently, and without efficient scaling, even groundbreaking research becomes difficult. In other words, model scaling capability has become essential knowledge.
2. The Real Keys to Model Scaling: Communication, Computation, and Memory
Parallelizing a model speeds up computation, but the communication cost between chips also increases. The moment it takes more time for chips to exchange data than to compute, no matter how many chips you add, things won't get any faster — this is the so-called "communication bound" bottleneck. You also need to consider maximum compute throughput (FLOP/s), memory bandwidth (read/write speeds), and total memory capacity.
"Even if a new TPU or GPU spec says 500 trillion operations per second, if the parameters are just shuttling around in memory, you might only get one-tenth of that performance in practice."
Understanding this interaction structure among computation, communication, and memory well allows you to:
- Predict where bottlenecks will occur, and
- Modify model design to avoid them.
From a hardware designer's perspective, the key challenge is minimizing cost while provisioning compute, bandwidth, and memory that are well matched to the algorithm. For example, TPUs were designed specifically for algorithms like matrix multiplication, which move very little data relative to the FLOPs they perform, giving outstanding cost-performance compared to GPUs at the time.
"TPUs succeeded thanks to the rapid growth of deep learning, but if networks had evolved in a different direction or if models had developed structures that TPUs couldn't handle, it could have been an enormous waste."
3. Book Structure: From Theory to Practice, Plus Tutorials
The main structure of the book is as follows. You don't have to read it in order — feel free to jump to whichever section you need.
Part 1: Prerequisites & Foundational Theory
- Introduction to Roofline Analysis: how algorithms are speed-limited by computation, communication, and memory.
- Understanding TPUs Properly: how TPU architecture and design affect model training and serving.
- Sharding and Parallel Matrix Multiplication: how to distribute (shard) models across multiple chips and implement distributed matrix multiplication.
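As a toy illustration of the sharding idea, here is a row-sharded matrix multiply in plain Python, with lists of chips standing in for real multi-device code. The row-sharding layout shown is just one simple scheme for illustration, not the book's specific recipe:

```python
def shard_rows(A, n_chips):
    """Split matrix A (a list of rows) into contiguous row blocks,
    one block per chip."""
    step = len(A) // n_chips
    return [A[i * step:(i + 1) * step] for i in range(n_chips)]

def local_matmul(A_block, B):
    """Each chip multiplies its row block by the (replicated) matrix B."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A_block]

def sharded_matmul(A, B, n_chips):
    """Row-sharded matmul: because every chip holds a full copy of B,
    no inter-chip communication is needed; results simply concatenate."""
    out = []
    for block in shard_rows(A, n_chips):
        out.extend(local_matmul(block, B))
    return out

A = [[1, 2], [3, 4], [5, 6], [7, 8]]
B = [[1, 0], [0, 1]]  # identity, so the product should equal A
result = sharded_matmul(A, B, 2)
```

Once B is also sharded rather than replicated, chips must exchange partial results, and that communication cost is exactly what the sharding chapter analyzes.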
Part 2: Deep Dive into Transformers
- Complete Transformer Math: practical calculations for FLOPs, parameter counts, KV cache sizes, and more during training and inference.
- Parallelism Methods for Transformers: various techniques from data, tensor, and pipeline parallelism to expert parallelism.
"Choosing the optimal combination of parallelism for a given number of chips might look simple, but in practice it's more complex than you'd think."
- Training LLaMA-3 on TPUs: a practical guide covering the time, cost, and other considerations for training an actual open-source model on TPUs.
- Transformer Inference A to Z: key latency issues during inference, KV cache, and large-scale serving design tips.
- LLaMA-3 TPU Serving Case Study: analysis of cost, speed, and tradeoffs when serving LLaMA-3 on the latest TPU (v5e).
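As a taste of the "Transformer math" these chapters cover, here is a minimal KV-cache size estimate. The LLaMA-like configuration numbers below are illustrative assumptions, not figures from the book:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch,
                   bytes_per_elem=2):
    """Size of the KV cache: 2 tensors (K and V) per layer, each of
    shape [batch, seq_len, n_kv_heads, head_dim]."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# Hypothetical LLaMA-like config: 80 layers, 8 KV heads of dim 128,
# an 8192-token context, batch size 1, bf16 (2 bytes per element).
gib = kv_cache_bytes(80, 8, 128, 8192, 1) / 2**30
```

Numbers like this drive serving design: the cache grows linearly with both batch size and context length, so it often dictates how many concurrent requests fit on a chip.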
Part 3: Hands-On Tutorials & Profiling
- TPU Code Profiling: JAX/XLA stack explanation, with examples of debugging real issues using the TensorBoard profiler.
- TPU Programming with JAX: leveraging JAX's rich parallelism APIs, presented through engaging worked examples.
- Final Summary & References: conclusion, recommended reading, and appendices.
4. Closing Thoughts: Essential Knowledge for the Era of Large Models
Parallelism and scaling know-how for large models is no longer a niche specialty. This guide crosses the boundaries between hardware, algorithms, and software to offer practical tips for balancing efficiency, cost, and performance.
"Even if you win 20% on a performance benchmark, it's meaningless if you lose 20% on roofline efficiency."
Throughout the guide, problems are posed for you to think through on your own, so trying out your own parallelism and optimization strategies in practice is also recommended. It's a living guidebook for the front lines of deep learning and AI; give it a try now!
