@jacobaustin123.bsky.social
Researcher at Google DeepMind. I make LLMs go fast. I also play piano and climb sometimes. Opinions my own
Reposted
jeffdean.bsky.social
Training our most capable Gemini models relies heavily on our JAX software stack + Google's TPU hardware platforms.

If you want to learn more, see this awesome book "How to Scale Your Model":

jax-ml.github.io/scaling-book/

Put together by several of my Google DeepMind colleagues listed below 🎉.
jacobaustin123.bsky.social
The book was co-written with @sholtodouglas.bsky.social, @froystig.bsky.social, @levskaya.bsky.social, @reinerpope.bsky.social, Albert Webson, Charlie Chen, Vinay Ramasesh, and Federico Lebron 10/n
jacobaustin123.bsky.social
LLM systems programming is super fun! It's hard to do good ML research without it these days, and you don't need much compute to work on it. I hope this book will make it easier for more people (esp. academics) to work on this stuff 9/n
jacobaustin123.bsky.social
The rest of the book is a set of practical guides: how to write and profile parallel JAX code, and how to apply the previous two sections to real models like LLaMA-3. We also have worked problems at the end of each section if you like homework: jax-ml.github.io/scaling-book... 8/n
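(Not from the book itself, but as a rough sketch of what "write and profile parallel JAX code" looks like in practice: shard an input across whatever devices are available, jit a matmul, and capture a trace with jax.profiler that you can open in TensorBoard or Perfetto. The sizes and the /tmp/jax-trace path are just illustrative.)

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over every available device (TPU cores, GPUs, or CPU).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard the batch dimension of the activations; replicate the weights.
x = jax.device_put(jnp.ones((8192, 4096), jnp.bfloat16),
                   NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((4096, 4096), jnp.bfloat16),
                   NamedSharding(mesh, P(None, None)))

@jax.jit
def forward(x, w):
    return jnp.dot(x, w)

# Capture a profile to inspect in TensorBoard or Perfetto.
with jax.profiler.trace("/tmp/jax-trace"):
    forward(x, w).block_until_ready()
```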
jacobaustin123.bsky.social
Now that we’ve talked about training, we need to talk about serving. How expensive should a model be to serve? What kind of latency can we expect? What are prefill and generation? How do we build an efficient inference service? We talk about this here: jax-ml.github.io/scaling-book... 7/n
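(A hedged back-of-the-envelope example of the kind of estimate that chapter teaches; the model size, chip count, and bandwidth below are assumptions picked for illustration, not the book's numbers. During generation every decoded token has to stream the weights out of HBM, so memory bandwidth sets a floor on per-token latency.)

```python
# Illustrative generation-latency floor (all figures assumed for the example).
params = 70e9                      # parameters
bytes_per_param = 2                # bfloat16 weights
hbm_bandwidth_per_chip = 800e9     # bytes/s, assumed
num_chips = 8

weight_bytes = params * bytes_per_param
# Each decoded token must read every weight from HBM at least once,
# so generation is typically memory-bandwidth bound at small batch sizes.
latency_floor = weight_bytes / (hbm_bandwidth_per_chip * num_chips)
print(f"~{latency_floor * 1e3:.0f} ms/token lower bound")   # ~22 ms
```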
jacobaustin123.bsky.social
Now for the good stuff! You may have heard of data or tensor parallelism, FSDP or pipelining. But why choose one over the other? Short answer: each adds communication, and the one with the lowest cost depends on the model. Part 5 dives into this: jax-ml.github.io/scaling-book... 6/n
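(To make "each adds communication" concrete, here is a toy comparison under assumptions of my own, not the book's: pure data parallelism all-reduces the gradients once per step, while Megatron-style tensor parallelism all-reduces activations inside every layer, so the cheaper option depends on batch size, model size, and interconnect bandwidth.)

```python
# Toy communication volumes (all inputs assumed for illustration).
params = 8e9            # model parameters
d_model = 8192          # hidden size
n_layers = 80
batch_tokens = 1e6      # tokens processed per step on this slice
bytes_per_el = 2        # bfloat16

# Data parallelism: one gradient all-reduce per step; a ring all-reduce
# moves roughly 2x the gradient bytes per participant.
dp_bytes = 2 * params * bytes_per_el

# Tensor parallelism (Megatron-style): ~2 activation all-reduces per layer
# in the forward pass alone, each moving ~2x (batch_tokens * d_model) bytes.
tp_bytes = n_layers * 2 * 2 * batch_tokens * d_model * bytes_per_el

print(f"data parallel:   ~{dp_bytes / 1e9:.0f} GB/step")        # ~32 GB
print(f"tensor parallel: ~{tp_bytes / 1e12:.1f} TB/fwd pass")   # ~5.2 TB
```

In practice the picture is muddier than this toy suggests: tensor-parallel traffic usually stays on the fastest links and shrinks with the per-chip batch, which is exactly the trade-off part 5 works through.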
jacobaustin123.bsky.social
5 years ago, there were many ML architectures, but today, there is (mostly) only one. _You should know the Transformer inside and out!_ How many FLOPs or params in LLaMA-3? How expensive is attention vs. a feed-forward block? You'll know after reading jax-ml.github.io/scaling-book... 5/n
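(A hedged taste of the kind of counting that chapter does, using the publicly reported LLaMA-3-70B shape, hidden size 8192 and SwiGLU intermediate size 28672; treat the numbers as approximate.)

```python
# Per-layer parameter counts for a LLaMA-3-70B-shaped Transformer (approximate).
d = 8192        # hidden size
d_ff = 28672    # SwiGLU intermediate size

# Feed-forward block: gate, up, and down projections.
mlp_params = 3 * d * d_ff              # ~7.0e8 per layer
# Attention projections (Q, K, V, output), ignoring grouped-query sharing.
attn_params = 4 * d * d                # ~2.7e8 per layer

# A matmul costs ~2 FLOPs per parameter per token, so per-token FLOPs track
# parameter counts (attention's extra QK^T / AV cost matters at long context).
print(f"MLP : attention params per layer ~ {mlp_params / attn_params:.1f} : 1")  # ~2.6 : 1
```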
jacobaustin123.bsky.social
Scaling an LLM involves distributing — a.k.a. "sharding" — its weights across multiple TPUs. To run it, we have to add cross-chip communication. Part 3 describes the TPU's communication primitives, and simple rules for multiplying sharded matrices: jax-ml.github.io/scaling-book... 4/n
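(For anyone who hasn't seen what those rules look like in code, here is a minimal shard_map sketch of my own, not the book's: each operand is sharded along the contracting dimension, every device multiplies its local blocks, and a psum all-reduce provides the cross-chip communication that makes the result correct.)

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

# A is sharded along its inner (contracting) dimension, B along its first.
# Each device computes a partial product; the psum all-reduce sums them.
@functools.partial(shard_map, mesh=mesh,
                   in_specs=(P(None, "x"), P("x", None)),
                   out_specs=P(None, None))
def sharded_matmul(a_blk, b_blk):
    partial = jnp.dot(a_blk, b_blk)
    return jax.lax.psum(partial, axis_name="x")

a = jnp.ones((1024, 1024), jnp.bfloat16)
b = jnp.ones((1024, 1024), jnp.bfloat16)
c = jax.jit(sharded_matmul)(a, b)   # matches an unsharded a @ b
```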
jacobaustin123.bsky.social
A big chunk of this book is dedicated to understanding the hardware that provides those system resources. We emphasize TPUs in this book, but the principles and math can be adapted to GPUs too. Part 2 explains the TPU in detail: jax-ml.github.io/scaling-book... 3/n
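(The core numbers to internalize from that chapter are a chip's peak FLOP/s and its HBM bandwidth; their ratio tells you how much arithmetic you must do per byte moved to stay compute-bound. A sketch with illustrative, not official, figures.)

```python
# Roofline-style reasoning with illustrative accelerator figures (not a spec sheet).
peak_flops = 2.0e14         # ~200 TFLOP/s of bf16 compute, assumed
hbm_bandwidth = 8.0e11      # ~800 GB/s, assumed

# FLOPs you must perform per byte loaded from HBM to be compute-bound.
critical_intensity = peak_flops / hbm_bandwidth
print(f"need ~{critical_intensity:.0f} FLOPs per byte")   # ~250

# A bf16 (B, D) x (D, F) matmul does 2*B*D*F FLOPs and moves ~2*(B*D + D*F + B*F)
# bytes, so its intensity is roughly B when D, F >> B: small batches are bandwidth-bound.
def matmul_intensity(b, d, f):
    return (2 * b * d * f) / (2 * (b * d + d * f + b * f))

print(f"{matmul_intensity(16, 8192, 8192):.0f}")     # ~16: bandwidth-bound
print(f"{matmul_intensity(1024, 8192, 8192):.0f}")   # ~819: compute-bound
```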
jacobaustin123.bsky.social
The secret is to think in terms of basic system resources — compute, memory, and bandwidth — and calculate which one limits our performance. From this we can estimate the cost, runtime, and optimal parallelism strategy for any given LLM: jax-ml.github.io/scaling-book/ 2/n
How To Scale Your Model
Training LLMs often feels like alchemy, but understanding and optimizing the performance of your models doesn't have to. This book aims to demystify the science of scaling language models on TPUs: how...
jax-ml.github.io
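(As one worked example of the whole pipeline the book describes, with every input below an assumption chosen for illustration: given a parameter count, a token budget, a fleet of chips, and a realistic utilization, the training wall-clock falls out of a few lines of arithmetic.)

```python
# Back-of-the-envelope training-time estimate (all inputs are illustrative assumptions).
params = 70e9                       # model parameters
tokens = 15e12                      # training tokens
flops_needed = 6 * params * tokens  # ~6 FLOPs per parameter per token (fwd + bwd)

num_chips = 8192
peak_flops_per_chip = 2.0e14        # assumed bf16 peak per chip
mfu = 0.4                           # assumed model-FLOPs utilization

seconds = flops_needed / (num_chips * peak_flops_per_chip * mfu)
print(f"~{seconds / 86400:.0f} days")   # ~111 days on this hypothetical fleet
```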
jacobaustin123.bsky.social
Making LLMs run efficiently can feel scary, but scaling isn’t magic, it’s math! We wanted to demystify the “systems view” of LLMs and wrote a little textbook called “How To Scale Your Model” which we’re releasing today. 1/n
jacobaustin123.bsky.social
Excited to be here! Hopefully the skies are brighter on this side of the fence. Will be posting research stuff here, mostly