Scaling Laws: Predicting Model Performance
How Chinchilla and OpenAI scaling laws relate model parameters, training tokens, and compute budget to loss. Use scaling laws to make optimal training decisions.
What Scaling Laws Tell Us
Scaling laws are empirical relationships between training loss and three variables:
- N: number of model parameters
- D: number of training tokens
- C: compute budget (measured in FLOPs)
The key finding (Kaplan et al., 2020; Hoffmann et al., 2022): test loss follows smooth power-law relationships with N and D. This means you can predict how well a model will perform before training it, given enough smaller training runs to fit the curve.
OpenAI Scaling Laws (Kaplan et al., 2020)
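The original OpenAI paper found that, when the other factor is not the bottleneck, loss follows a power law in model size and in dataset size, and it fit a combined form for both:

    L(N) = (N_c / N)^α_N
    L(D) = (D_c / D)^α_D
    L(N, D) = [ (N_c / N)^(α_N / α_D) + D_c / D ]^α_D

where:

    N_c, D_c = fitted scale constants (in parameters and tokens), not thresholds at which the model stops improving
    α_N ≈ 0.076
    α_D ≈ 0.095

These fits contain no separate irreducible-loss term; the Chinchilla form later in this article adds one (E, the entropy of the data).
An important result: under a fixed compute budget, the fits implied that model size mattered more than data size. The recommended strategy was to train very large models on relatively little data.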
This led to GPT-3 (175B parameters) being trained on ~300B tokens.
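To get a feel for how shallow these exponents are, the short sketch below uses only the α_N and α_D values quoted above. It assumes a pure power law of the form L ∝ X^(-α) with nothing else limiting; the helper names are illustrative and do not come from the paper.

```python
# Exponents reported by Kaplan et al. (2020), as quoted above.
ALPHA_N = 0.076  # model-size exponent
ALPHA_D = 0.095  # dataset-size exponent


def loss_ratio(scale_factor: float, alpha: float) -> float:
    """Factor by which loss changes when N (or D) is multiplied by scale_factor,
    assuming a pure power law L ∝ X^(-alpha)."""
    return scale_factor ** (-alpha)


def scale_needed(target_loss_ratio: float, alpha: float) -> float:
    """Multiplier on N (or D) needed to shrink loss to target_loss_ratio of its value."""
    return target_loss_ratio ** (-1.0 / alpha)


for name, alpha in [("parameters", ALPHA_N), ("tokens", ALPHA_D)]:
    per_doubling = 1.0 - loss_ratio(2.0, alpha)
    to_halve = scale_needed(0.5, alpha)
    print(f"Doubling {name}: loss falls by {per_doubling:.1%}")
    print(f"Halving loss via {name} alone: about {to_halve:,.0f}x more")
```

Under these assumptions each doubling buys only a few percent of loss, which is why scaling-law plots are drawn on log-log axes. These per-axis numbers do not by themselves say how to split a compute budget; that comes from the compute-optimal analysis.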
Chinchilla Laws (Hoffmann et al., 2022): Compute-Optimal Training
DeepMind found that Kaplan et al.'s analysis underweighted data. For the same compute budget, a smaller model trained on more tokens reaches lower loss, and it is also cheaper to run at inference.
Chinchilla optimal: For a given compute budget C, the compute-optimal model size N* and token count D* satisfy:

    N* ∝ C^0.5
    D* ∝ C^0.5
    D* / N* ≈ 20 tokens per parameter

    def chinchilla_optimal(compute_flops: float) -> dict:
        """
        Estimate compute-optimal model size and token count.

        compute_flops: total training FLOPs (approximately 6 × N × D)
        """
        # From Chinchilla paper coefficients
        # L = E + A/N^α + B/D^β
        # Optimal: N* ∝ (compute)^0.5 × constant, D* ≈ 20 × N*
        # Rough approximation: 6 FLOPs per token per parameter (forward + backward)
        #   C ≈ 6 × N × D  →  N × D ≈ C / 6
        # At Chinchilla optimal: D = 20N  →  N × 20N = C/6  →  N² = C/120  →  N = sqrt(C/120)
        flops_per_token_per_param = 6.0
        tokens_per_param = 20.0
        N_optimal = (compute_flops / (flops_per_token_per_param * tokens_per_param)) ** 0.5
        D_optimal = tokens_per_param * N_optimal
        return {
            "optimal_params": N_optimal,
            "optimal_tokens": D_optimal,
            "params_billions": N_optimal / 1e9,
            "tokens_billions": D_optimal / 1e9,
            "tokens_per_param": D_optimal / N_optimal,
        }

    # Example: GPT-3's approximate training compute,
    # 6 × 175B params × 300B tokens ≈ 3.15e23 FLOPs
    gpt3_compute = 6 * 175e9 * 300e9
    result = chinchilla_optimal(gpt3_compute)
    print(f"Optimal model size: {result['params_billions']:.1f}B parameters")
    print(f"Optimal token count: {result['tokens_billions']:.1f}B tokens")
    print(f"Tokens per parameter: {result['tokens_per_param']:.1f}")
    # For reference: Chinchilla-70B used 1.4T tokens for 70B params → 20 tokens/param

GPT-3 (175B params, 300B tokens) is over-parameterized under Chinchilla: given its compute, the optimal model is much smaller and trained on far more data (roughly 51B parameters on about 1T tokens under this approximation).
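The square-root scaling has a practical consequence: extra compute is split evenly, on a log scale, between a bigger model and more data. The sketch below tabulates the rule of thumb across several budgets. The function name `compute_optimal_split` is made up for this illustration, and both defaults (6 FLOPs per parameter per token, 20 tokens per parameter) are the approximations used in the text rather than exact fits.

```python
import math


def compute_optimal_split(compute_flops: float,
                          tokens_per_param: float = 20.0,
                          flops_per_param_token: float = 6.0) -> tuple[float, float]:
    """Return (params, tokens) under C ≈ 6·N·D and D ≈ 20·N.

    Both defaults are the rule-of-thumb values from the text, not exact fits.
    """
    n_params = math.sqrt(compute_flops / (flops_per_param_token * tokens_per_param))
    return n_params, tokens_per_param * n_params


budgets = [1e21, 1e22, 1e23, 1e24, 1e25]
print(f"{'FLOPs':>8}  {'params (B)':>11}  {'tokens (B)':>11}")
for c in budgets:
    n, d = compute_optimal_split(c)
    print(f"{c:8.0e}  {n / 1e9:11.2f}  {d / 1e9:11.1f}")

# Square-root scaling: 100x more compute gives 10x the parameters and 10x the tokens.
n_small, d_small = compute_optimal_split(1e22)
n_large, d_large = compute_optimal_split(1e24)
print(f"params ratio: {n_large / n_small:.1f}, tokens ratio: {d_large / d_small:.1f}")
```

Put differently, a 10× larger budget buys a model only about 3.2× larger, trained on about 3.2× more tokens.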
The Chinchilla Model
Hoffmann et al. fit the loss function:
    L(N, D) = E + A / N^α + B / D^β

Fitted values:

    E = 1.69 (irreducible loss, the entropy of the data)
    A = 406.4
    B = 410.7
    α = 0.34
    β = 0.28

    def predict_loss(
        N: float,  # parameters
        D: float,  # tokens
        E: float = 1.69,
        A: float = 406.4,
        B: float = 410.7,
        alpha: float = 0.34,
        beta: float = 0.28,
    ) -> float:
        """Predict cross-entropy loss using Chinchilla scaling law."""
        return E + A / (N ** alpha) + B / (D ** beta)

    # Compare GPT-3 with Chinchilla-70B:
    gpt3_loss = predict_loss(N=175e9, D=300e9)
    chinchilla_70b_loss = predict_loss(N=70e9, D=1.4e12)
    print(f"GPT-3 predicted loss: {gpt3_loss:.3f}")
    print(f"Chinchilla-70B predicted loss: {chinchilla_70b_loss:.3f}")
    # Chinchilla-70B is predicted to reach lower (better) loss with 2.5× fewer parameters,
    # though its training run used roughly 1.9× the compute of GPT-3 (by 6 × N × D)

Why This Changed the Field
Before Chinchilla:
- GPT-3, Gopher, and PaLM were very large models trained on relatively little data
- The prevailing belief was that, for a fixed compute budget, a bigger model was the better use of it
After Chinchilla:
- LLaMA-1 (65B on 1.4T tokens), LLaMA-2 (70B on 2T tokens): smaller, well-trained models
- LLaMA-3-8B trained on 15T tokens ≈ 1,875 tokens per parameter (far beyond Chinchilla optimal)
- Modern practice: train smaller models on much more data because inference cost matters
The insight: A 7B model trained on 2T tokens is cheaper to run at inference than a 70B model trained on 200B tokens, and may match it on many benchmarks.
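The two configurations in this comparison can be put side by side with the approximations already used in this article. The sketch below assumes the 6·N·D rule for training FLOPs, the Chinchilla fitted constants quoted above, and an additional rule of thumb of about 2 FLOPs per parameter per generated token for inference, which is an assumption introduced here. All function names are illustrative.

```python
# Chinchilla parametric fit as quoted in this article (Hoffmann et al., 2022).
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28


def predicted_loss(n_params: float, n_tokens: float) -> float:
    """Loss predicted by L = E + A/N^alpha + B/D^beta."""
    return E + A / n_params ** ALPHA + B / n_tokens ** BETA


def training_flops(n_params: float, n_tokens: float) -> float:
    """Rule of thumb: about 6 FLOPs per parameter per training token."""
    return 6.0 * n_params * n_tokens


def inference_flops_per_token(n_params: float) -> float:
    """Assumed rule of thumb: about 2 FLOPs per parameter per generated token
    (forward pass only, ignoring attention and KV-cache costs)."""
    return 2.0 * n_params


configs = {
    "7B on 2T tokens": (7e9, 2e12),
    "70B on 200B tokens": (70e9, 200e9),
}
for name, (n, d) in configs.items():
    print(f"{name}: train {training_flops(n, d):.2e} FLOPs, "
          f"inference {inference_flops_per_token(n):.1e} FLOPs/token, "
          f"predicted loss {predicted_loss(n, d):.3f}")
```

Under these assumptions both runs cost the same training compute (about 8.4e22 FLOPs), the 7B model is 10× cheaper per generated token, and the fitted law predicts a slightly lower loss for it. The constants were fit on the Chinchilla paper's own data, so the absolute loss values should not be expected to transfer to other datasets or tokenizers.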
Scaling Laws in Practice
Using smaller models to predict large model performance
    import numpy as np
    from scipy.optimize import curve_fit

    def power_law(n, a, alpha):
        """Simple power law: loss = a × N^(-alpha)"""
        return a * n ** (-alpha)

    # Fit on small model runs
    model_sizes = np.array([125e6, 350e6, 760e6, 1.3e9, 2.7e9])
    observed_losses = np.array([3.12, 2.85, 2.67, 2.54, 2.41])  # Example validation losses

    # Give the optimizer a starting guess near the expected scale;
    # the default guess (a=1, alpha=1) is far from this data.
    params, _ = curve_fit(power_law, model_sizes, observed_losses, p0=[10.0, 0.1])
    a_fit, alpha_fit = params

    # Predict 70B model loss
    predicted_70b_loss = power_law(70e9, a_fit, alpha_fit)
    print(f"Fitted power law: loss = {a_fit:.2f} × N^(-{alpha_fit:.3f})")
    print(f"Predicted 70B loss: {predicted_70b_loss:.3f}")

Estimating compute requirements
    def estimate_training_compute(
        num_params: float,
        num_tokens: float,
        flops_per_param_per_token: float = 6.0,  # Standard approximation (forward + backward)
    ) -> dict:
        """Estimate total training FLOPs and the equivalent single-GPU time."""
        # Rule of thumb: ~6 FLOPs per token per parameter (forward + backward)
        total_flops = flops_per_param_per_token * num_params * num_tokens

        # Convert to GPU-hours (A100: ~312 TFLOP/s peak for BF16)
        a100_peak_flops_per_sec = 312e12
        mfu = 0.40  # Model FLOPs utilization: typically 35-50%
        effective_flops_per_sec = a100_peak_flops_per_sec * mfu
        seconds = total_flops / effective_flops_per_sec
        gpu_hours = seconds / 3600
        return {
            "total_flops": total_flops,
            "gpu_hours_single_a100": gpu_hours,
            "gpu_days_single_a100": gpu_hours / 24,
        }

    # LLaMA-3-8B: 8B params, 15T tokens
    compute = estimate_training_compute(8e9, 15e12)
    print(f"Training FLOPs: {compute['total_flops']:.2e}")
    print(f"Single A100 GPU-hours: {compute['gpu_hours_single_a100']:.0f}")
    # Assumes perfect linear scaling across GPUs
    print(f"With 1024 A100s: {compute['gpu_days_single_a100']/1024:.0f} days")

Beyond Loss: Task Performance Scaling
Scaling laws for perplexity/loss don't always translate directly to task performance. Some tasks show:
- Smooth scaling: Translation, summarization, and code completion improve gradually with scale
- Emergent capabilities: Chain-of-thought reasoning and arithmetic appear suddenly at a threshold scale
Emergent behavior: a capability that is absent in small models and then abruptly appears as scale increases, with no gradual intermediate stage.
This makes scaling law extrapolation for specific capabilities less reliable than for overall loss. A model that achieves its predicted loss may still fail on specific reasoning tasks that only emerge at larger scale.
Key Takeaways for Practitioners
- Data efficiency: 20 tokens per parameter (Chinchilla) is the compute-optimal starting point. Modern practice trains beyond this for better inference-time efficiency.
- Fit curves before committing: Run 5 to 10 small training runs to fit a power law before training large. This is standard at ML research labs.
- Loss predicts loss, not tasks: Be cautious extrapolating scaling laws to task performance, since emergent behaviors break smooth scaling.
- Inference cost matters: Chinchilla optimal for training is not optimal for deployment. If you'll serve millions of requests, a smaller, better-trained model wins on total cost.
- Compute budget allocation: Given a fixed compute budget, the optimal split is roughly equal between model size and data size (both scale as C^0.5).
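The "roughly equal" split in the last takeaway can be cross-checked against the fitted loss L = E + A/N^α + B/D^β. Minimizing that loss subject to C ≈ 6·N·D gives N* ∝ C^(β/(α+β)) and D* ∝ C^(α/(α+β)). The sketch below computes those exponents from the α and β quoted earlier and confirms the first one numerically with a brute-force search. The function name and the search grid are illustrative choices, not part of the paper.

```python
import numpy as np

# Fitted constants as quoted earlier in this article (Hoffmann et al., 2022).
A, B, ALPHA, BETA = 406.4, 410.7, 0.34, 0.28

# Closed-form exponents from minimizing A/N^alpha + B/D^beta subject to 6*N*D = C.
exp_params = BETA / (ALPHA + BETA)   # N* ∝ C^exp_params
exp_tokens = ALPHA / (ALPHA + BETA)  # D* ∝ C^exp_tokens
print(f"N* exponent: {exp_params:.3f}, D* exponent: {exp_tokens:.3f}")


def best_params_by_search(compute_flops: float) -> float:
    """Brute-force the loss-minimizing N on a dense log-spaced grid, with D = C / (6 N)."""
    n_grid = np.logspace(6, 13, 200_001)
    d_grid = compute_flops / (6.0 * n_grid)
    reducible_loss = A / n_grid ** ALPHA + B / d_grid ** BETA
    return float(n_grid[np.argmin(reducible_loss)])


# Slope of log N* against log C across four orders of magnitude of compute.
c_low, c_high = 1e20, 1e24
slope = (np.log(best_params_by_search(c_high) / best_params_by_search(c_low))
         / np.log(c_high / c_low))
print(f"Exponent recovered by search: {slope:.3f}")
```

With α = 0.34 and β = 0.28 the exponents come out near 0.45 and 0.55, close to but not exactly an even split, so C^0.5 is best read as a rounded rule of thumb.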