Tian Jin
@tjin.bsky.social
57 followers 78 following 35 posts
PhD student @MIT CSAIL
Reposted by Tian Jin
gkdziugaite.bsky.social
Excited to share our research on what matters in sparse LLM pre-training. Stop by our poster @ ICLR 🗓️ April 24th session #2.
tjin.bsky.social
📣 The Journey Matters: Our #ICLR2025 paper shows how to pre-train sparse LLMs at half the size of dense LLMs while maintaining quality. We found that the average parameter count during sparse pre-training predicts quality, not final size. An MIT/Rice/Google/ISTA collab 🧵 1/N
Reposted by Tian Jin
suvinay.bsky.social
Scaling Laws provide a valuable lens for guiding model design and computational budgets. Our recent work extends this lens to the realm of _fine-grained_ sparsity. Check out our #ICLR2025 paper, and the thread below from lead author @tjin.bsky.social summarizing our findings.
Reposted by Tian Jin
roydanroy.bsky.social
Tian and Karolina and team are at ICLR. Come say hi.
tjin.bsky.social
Please visit us at Poster Session #2 on April 24th. Looking forward to meeting you all at ICLR! 9/N
tjin.bsky.social
We're releasing 4 pairs of sparse/dense LLMs (~1B parameters) with matching average parameter counts and identical training tokens, at sparsity levels from 20-80% (see 2nd post for code/models). 8/N
tjin.bsky.social
Key insight: sparse pre-training decouples average parameter count (which governs quality) from final parameter count (which determines inference costs), enabling a new Pareto frontier on the training cost vs. inference efficiency trade-off. 7/N
tjin.bsky.social
Our results validate this unified scaling law across 30 LLM pre-training runs, spanning models from 58M to 468M parameters with up to 80% final sparsity, trained on up to 20x the Chinchilla-optimal token budget. 6/N
tjin.bsky.social
Building on this observation, we extend the Chinchilla scaling law to cover both sparse and dense pre-training by using the average param count during pre-training as the model-size term. 5/N
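For concreteness, here is a sketch of what that extension can look like if one starts from the standard Chinchilla parametric loss and substitutes the average parameter count for the model-size term. The notation below (N̄ for average parameter count, N_t for parameter count at step t, D for training tokens) and the fitted constants are assumptions for illustration, not values from the paper:

```latex
% Chinchilla-style parametric loss, with the model-size term replaced by
% \bar{N}, the average parameter count over pre-training (assumed notation).
% D = number of training tokens; E, A, B, \alpha, \beta are fitted constants.
L(\bar{N}, D) = E + \frac{A}{\bar{N}^{\alpha}} + \frac{B}{D^{\beta}},
\qquad
\bar{N} = \frac{1}{T} \sum_{t=1}^{T} N_t
```

For dense pre-training N_t is constant, so N̄ reduces to the usual model size and the formula collapses back to the standard Chinchilla law.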
tjin.bsky.social
Surprisingly, this pair of models achieves matching quality (eval ppl)! We tested this on models with 162M starting parameters and 20-80% final sparsity. The results are consistent: sparse and dense models with the same average param count reach the same final eval loss. 4/N
tjin.bsky.social
Consider two parameter-count vs. training-step curves with equal area under the curve, i.e., equal training FLOPs. Solid line = dense pre-training, dashed line = sparse pre-training with gradual pruning. While they differ in final param count, they match in average param count. 3/N
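As a quick numeric illustration of "average parameter count" (area under the parameter-count-vs-step curve divided by the number of steps), here is a toy sketch. The linear pruning ramp, the step count, and the dense model size are made-up values; only the 162M starting size appears in this thread:

```python
# Toy illustration: average parameter count = area under the
# parameter-count-vs-step curve / number of steps. The schedule and sizes
# below are illustrative, not the paper's actual pruning schedule or configs.
import numpy as np

T = 10_000                      # training steps (illustrative)
steps = np.arange(T)

# Dense run: constant parameter count throughout training.
dense = np.full(T, 121.5e6)

# Sparse run: start at 162M, gradually prune to 50% sparsity (81M) over the
# middle half of training, then train the final sparse model to the end.
sparse = np.interp(steps,
                   [0, 0.25 * T, 0.75 * T, T - 1],
                   [162e6, 162e6, 81e6, 81e6])

for name, curve in [("dense", dense), ("sparse", sparse)]:
    print(f"{name:6s} final = {curve[-1] / 1e6:6.1f}M   "
          f"average = {curve.mean() / 1e6:6.1f}M")
# Different final sizes (121.5M vs 81M), matching averages (~121.5M): per the
# observation above, the two runs should reach similar final eval loss.
```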
tjin.bsky.social
This is joint work with
Ahmed Imtiaz Humayun,
Utku Evci
@suvinay.bsky.social
Amir Yazdanbakhsh,
@dalistarh.bsky.social
@gkdziugaite.bsky.social

Project/Code/Models: sparsellm.com
Paper: arxiv.org/abs/2501.12486
Session: April 24 Poster Session #2 (Hall 3 + Hall 2B #342)
2/N
The Journey Matters: Average Parameter Count over Pre-training Unifies Sparse and Dense Scaling Laws
sparsellm.com
tjin.bsky.social
📣 The Journey Matters: Our #ICLR2025 paper shows how to pre-train sparse LLMs at half the size of dense LLMs while maintaining quality. We found that the average parameter count during sparse pre-training predicts quality, not final size. An MIT/Rice/Google/ISTA collab 🧵 1/N
tjin.bsky.social
Our work adds to an emerging line of work we call Asynchronous Decoding. Unlike synchronous (parallel) decoding (e.g., speculative decoding), which decodes one contiguous chunk in parallel, asynchronous decoding decodes multiple non-contiguous chunks in parallel, letting LLMs jump ahead during decoding. 13/N
tjin.bsky.social
To make this all work efficiently, we designed a high-performance interpreter that acts on Pasta-Lang annotations to orchestrate asynchronous decoding on the fly during LLM decoding. 12/N
tjin.bsky.social
The quality-speedup trade-off keeps improving with more training, showing no signs of saturation! We took 4 snapshots at different points of preference optimization (10% of Round 1, 100% of R1, 10% of R2, 60% of R2). As we train more, this trade-off improves toward the optimal top-right corner. 11/N
tjin.bsky.social
We show that PASTA Pareto-dominates all existing async decoding methods! We achieve geometric mean speedups ranging from 1.21× to 1.93× with corresponding quality changes of +2.2% to -7.1%, measured by length-controlled win rates against the sequential decoding baseline. 10/N
tjin.bsky.social
We then use these scored examples for preference optimization, teaching the model to generate responses that are both fast and high quality. A quality-weight hyperparameter λ lets us tune which aspect (quality vs. speed) to prioritize. 9/N
tjin.bsky.social
Stage 2: This is where it gets interesting! For each instruction prompt, we sample multiple Pasta-annotated responses and score them based on:
- Decoding latency (how fast and parallel is the decoding?)
- Response quality (evaluated by another LLM)
8/N
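A hedged sketch of how these two scores might be combined into a single preference signal. The combination rule, helper names, and default λ below are illustrative guesses, not the paper's actual recipe:

```python
# Illustrative only: combine judged quality and measured decoding latency
# into one score, then keep a (chosen, rejected) pair of sampled responses
# for preference optimization. The scoring rule and lam default are guesses.
from dataclasses import dataclass

@dataclass
class Scored:
    response: str     # a Pasta-annotated response sampled from the model
    quality: float    # e.g., judge score in [0, 1] from another LLM
    latency: float    # measured decoding latency (normalized)

def combined_score(s: Scored, lam: float) -> float:
    # lam is the quality-weight hyperparameter described above:
    # larger lam favors quality, smaller lam favors speed.
    return lam * s.quality - (1.0 - lam) * s.latency

def preference_pair(samples: list[Scored], lam: float = 0.8):
    ranked = sorted(samples, key=lambda s: combined_score(s, lam), reverse=True)
    return ranked[0].response, ranked[-1].response   # (chosen, rejected)
```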
tjin.bsky.social
Stage 1: We first prompt the Gemini model to annotate instruction-following responses with Pasta-Lang. We then finetune our base LLM on this dataset to learn the basic syntax and semantics of Pasta-Lang annotations. 7/N
tjin.bsky.social
How do we train LLMs to do this? Through a two-stage training process that requires fewer than 10 human annotations! 6/N
tjin.bsky.social
The <sync/> tag signals that subsequent decoding depends on asynchronously decoded chunks. At this point, the interpreter pauses to wait for all async decoding to complete before proceeding, ensuring correctness whenever later text depends on async-decoded chunks. 5/N
tjin.bsky.social
To enable this, we introduce PASTA-LANG (PArallel STructure Annotation) tags and an interpreter: <promise/> tags are placeholders for semantically independent chunks, and <async> tags wrap each such chunk, which the interpreter decodes asynchronously in parallel. 4/N
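For intuition, here is a toy Python sketch of the control flow these tags imply. It is not the actual PASTA interpreter (which orchestrates real parallel decode streams and their KV caches inside the serving stack): "decoding" is faked with a sleep, the tag syntax is simplified (no attributes), and each <async> body is placed inline where its <promise/> placeholder would sit in the real format:

```python
# Toy sketch only: treat <async>...</async> spans as chunks that can be
# "decoded" concurrently, and <sync/> as a barrier that waits for them.
import re
import time
from concurrent.futures import Future, ThreadPoolExecutor

def fake_decode(chunk: str) -> str:
    """Stand-in for decoding one independent chunk in its own stream."""
    time.sleep(0.1)
    return chunk

def interpret(annotated: str) -> str:
    # Split into plain text, <async>...</async> chunks, and <sync/> barriers.
    parts = re.split(r"(<async>.*?</async>|<sync/>)", annotated, flags=re.S)
    pending: list[Future] = []
    out: list[str] = []
    with ThreadPoolExecutor() as pool:
        for part in parts:
            if part.startswith("<async>"):
                body = part[len("<async>"):-len("</async>")]
                pending.append(pool.submit(fake_decode, body))     # decode in parallel
            elif part == "<sync/>":
                out.append(" ".join(f.result() for f in pending))  # barrier: wait
                pending.clear()
            else:
                out.append(part)                                   # sequential text
        out.append(" ".join(f.result() for f in pending))          # flush leftovers
    return "".join(out)

print(interpret(
    "To compute the segment length:"
    "<async> extract the endpoint coordinates,</async>"
    "<async> recall the distance formula,</async>"
    "<sync/> then plug the coordinates into the formula."
))
```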
tjin.bsky.social
We developed and evaluated a suite of such LLMs capable of asynchronous parallel decoding on a benchmark of 805 prompts from AlpacaEval. One model shows a 1.46x geometric-mean speedup with a small quality drop of 1.3%, measured by length-controlled win rates. 3/N
tjin.bsky.social
In the figure above, when computing a line segment's length, extracting the coordinates and recalling the distance formula are two semantically independent chunks. Our system trains the LLM to identify such chunks and decode both asynchronously in parallel, then sync before the final calculation! 2/N
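To make that example concrete, here is a hypothetical Pasta-Lang-style annotation of such a response. Only the tag names and their roles come from this thread; the placement of the placeholders, the absence of attributes, and the numbers are illustrative guesses, not the paper's actual syntax:

```
To find the length of segment AB:
<promise/>   <!-- placeholder for chunk 1: extract the coordinates -->
<promise/>   <!-- placeholder for chunk 2: recall the distance formula -->
<async>The endpoints are A = (1, 2) and B = (4, 6).</async>
<async>The distance formula is d = sqrt((x2 - x1)^2 + (y2 - y1)^2).</async>
<sync/>
So d = sqrt(3^2 + 4^2) = sqrt(25) = 5.
```

The interpreter decodes the two <async> chunks in parallel and waits at <sync/> before the dependent final step.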