Learnixo
AI Systems · Intermediate

Scaling Laws: Predicting Model Performance

How Chinchilla and OpenAI scaling laws relate model parameters, training tokens, and compute budget to loss. Use scaling laws to make optimal training decisions.

Asma Hafeez Khan · May 16, 2026 · 6 min read
Tags: Transformers, Scaling Laws, Training, Research

What Scaling Laws Tell Us

Scaling laws are empirical relationships between training loss and three variables:

  • N — number of model parameters
  • D — number of training tokens
  • C — compute budget (measured in FLOPs)

The key finding (Kaplan et al., 2020; Hoffmann et al., 2022): test loss follows smooth power-law relationships with N, D, and C. This means you can predict a large model's loss before training it, given enough smaller training runs to fit the curve.


OpenAI Scaling Laws (Kaplan et al., 2020)

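The original OpenAI paper found that, when the other factors are not a bottleneck, loss follows a power law in each variable separately:

L(N) = (N_c / N)^α_N
L(D) = (D_c / D)^α_D

and proposed a joint form for finite N and D:

L(N, D) = [ (N_c / N)^(α_N / α_D) + D_c / D ]^α_D

where:
  N_c, D_c = fitted constants that set the scale of each law (not points where the model stops improving)
  α_N ≈ 0.076 (from the L(N) fit)
  α_D ≈ 0.095 (from the L(D) fit)

Unlike the Chinchilla form below, these fits include no explicit irreducible-loss term.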

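An important result at the time: model size mattered more than data size. For a fixed compute budget, Kaplan et al. recommended growing parameters much faster than tokens (roughly N ∝ C^0.73), which meant training very large models on relatively little data and stopping well before convergence.

GPT-3 (175B parameters, trained on ~300B tokens) reflects this strategy.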
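To make the difference between the two recipes concrete, here is a small sketch of how each one splits extra compute. It assumes Kaplan's compute-optimal exponent N ∝ C^0.73 and the Chinchilla exponent N ∝ C^0.5 (discussed in the next section), and uses C ≈ 6 × N × D, so tokens grow as C^(1 - exponent). `allocation_multipliers` is a hypothetical helper, not code from either paper.

```python
def allocation_multipliers(compute_multiplier: float, param_exponent: float) -> tuple[float, float]:
    """Hypothetical helper: how much parameters and tokens grow when compute grows.

    Assumes N ∝ C^param_exponent and C ≈ 6 × N × D, so D ∝ C^(1 - param_exponent).
    """
    param_mult = compute_multiplier ** param_exponent
    token_mult = compute_multiplier ** (1.0 - param_exponent)
    return param_mult, token_mult


# Assumed exponents: Kaplan et al. (2020) N ∝ C^0.73; Chinchilla N ∝ C^0.5.
for label, exponent in [("Kaplan", 0.73), ("Chinchilla", 0.50)]:
    param_mult, token_mult = allocation_multipliers(10.0, exponent)
    print(f"{label}: 10x compute -> {param_mult:.2f}x params, {token_mult:.2f}x tokens")
```

For a 10× compute increase this gives about 5.4× more parameters and 1.9× more tokens under Kaplan's exponent, versus about 3.2× each under Chinchilla's.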


Chinchilla Laws (Hoffmann et al., 2022): Compute-Optimal Training

DeepMind found that the large models of the time were undertrained: Kaplan's allocation under-weighted data. At a fixed compute budget, a smaller model trained on more tokens reached lower loss, and it is also cheaper to serve.

Chinchilla optimal: For a given compute budget C, the compute-optimal model size N* and token count D* satisfy:

N* āˆ C^0.5
D* āˆ C^0.5
D* / N* ā‰ˆ 20 tokens per parameter
Python
def chinchilla_optimal(compute_flops: float) -> dict:
    """
    Estimate compute-optimal model size and token count.
    compute_flops: total training FLOPs (approximately 6 × N × D)
    """
    # From the Chinchilla paper:
    # L = E + A/N^α + B/D^β
    # Optimal: N* ≈ constant × C^0.5, D* ≈ 20 × N*

    # Rough approximation: 6 FLOPs per token per parameter (forward + backward)
    # C ≈ 6 × N × D  →  N × D ≈ C / 6
    # At the Chinchilla optimum D = 20N  →  N × 20N = C/6  →  N² = C/120  →  N = sqrt(C/120)

    flops_per_token_per_param = 6.0
    tokens_per_param = 20.0

    N_optimal = (compute_flops / (flops_per_token_per_param * tokens_per_param)) ** 0.5
    D_optimal = tokens_per_param * N_optimal

    return {
        "optimal_params": N_optimal,
        "optimal_tokens": D_optimal,
        "params_billions": N_optimal / 1e9,
        "tokens_billions": D_optimal / 1e9,
        "tokens_per_param": D_optimal / N_optimal,
    }

# Example: GPT-3's training budget under the same approximation
# (6 × 175B params × 300B tokens ≈ 3.15e23 FLOPs)
result = chinchilla_optimal(6 * 175e9 * 300e9)
print(f"Optimal model size: {result['params_billions']:.1f}B parameters")   # ≈ 51.2B
print(f"Optimal token count: {result['tokens_billions']:.1f}B tokens")      # ≈ 1024.7B
print(f"Tokens per parameter: {result['tokens_per_param']:.1f}")           # 20.0
# Chinchilla-70B used 1.4T tokens for 70B params ≈ 20 tokens/param

GPT-3 (175B params, 300B tokens) is over-parameterized under Chinchilla: by the 20-tokens-per-parameter rule, its budget (about 3e23 FLOPs) would be better spent on a model of roughly 50B parameters trained on about 1T tokens.


The Chinchilla Model

Hoffmann et al. fit the loss function:

L(N, D) = E + A / N^α + B / D^β

Fitted values:
  E = 1.69  (irreducible loss — entropy of the data)
  A = 406.4
  B = 410.7
  α = 0.34
  β = 0.28
Python
def predict_loss(
    N: float,  # parameters
    D: float,  # tokens
    E: float = 1.69,
    A: float = 406.4,
    B: float = 410.7,
    alpha: float = 0.34,
    beta: float = 0.28,
) -> float:
    """Predict cross-entropy loss using Chinchilla scaling law."""
    return E + A / (N ** alpha) + B / (D ** beta)

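# Compare GPT-3 vs Chinchilla-70B (predictions are for the fit's own training data,
# not GPT-3's actual corpus):
gpt3_loss = predict_loss(N=175e9, D=300e9)
chinchilla_70b_loss = predict_loss(N=70e9, D=1.4e12)

print(f"GPT-3 predicted loss: {gpt3_loss:.3f}")                    # ≈ 2.002
print(f"Chinchilla-70B predicted loss: {chinchilla_70b_loss:.3f}")  # ≈ 1.937
# Chinchilla-70B achieves lower (better) predicted loss with 2.5× fewer parameters,
# though under 6 × N × D it also used ~1.9× more training compute (5.9e23 vs 3.2e23 FLOPs)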
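Comparing GPT-3 with Chinchilla-70B directly mixes two different budgets: under C ≈ 6 × N × D, Chinchilla-70B's run (about 5.9e23 FLOPs) is roughly 1.9× GPT-3's (about 3.2e23). A minimal sketch of an equal-compute comparison, assuming the fitted constants listed above and the 20-tokens-per-parameter rule; `chinchilla_loss` and `compute_optimal_20_to_1` are illustrative helpers, and the predicted losses describe the Chinchilla fit's data distribution, not GPT-3's.

```python
import math


def chinchilla_loss(n: float, d: float) -> float:
    """Chinchilla parametric fit: E + A/N^alpha + B/D^beta (constants from Hoffmann et al., 2022)."""
    return 1.69 + 406.4 / n ** 0.34 + 410.7 / d ** 0.28


def compute_optimal_20_to_1(compute_flops: float) -> tuple[float, float]:
    """Rule of thumb: D = 20 N and C = 6 N D, so N = sqrt(C / 120)."""
    n = math.sqrt(compute_flops / 120.0)
    return n, 20.0 * n


gpt3_n, gpt3_d = 175e9, 300e9
budget = 6.0 * gpt3_n * gpt3_d  # GPT-3's budget under the 6·N·D approximation

opt_n, opt_d = compute_optimal_20_to_1(budget)
gpt3_loss = chinchilla_loss(gpt3_n, gpt3_d)
opt_loss = chinchilla_loss(opt_n, opt_d)

print(f"Budget: {budget:.2e} FLOPs")
print(f"GPT-3 shape: {gpt3_n / 1e9:.0f}B params, {gpt3_d / 1e9:.0f}B tokens, loss {gpt3_loss:.3f}")
print(f"20:1 shape:  {opt_n / 1e9:.0f}B params, {opt_d / 1e9:.0f}B tokens, loss {opt_loss:.3f}")
```

Under this fit, the roughly 51B-parameter, 1T-token configuration reaches a lower predicted loss (about 1.96 versus about 2.00) for exactly the same training FLOPs.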

Why This Changed the Field

Before Chinchilla:

  • GPT-3, PaLM, Gopher trained very large models on relatively little data
  • The prevailing view: with a fixed compute budget, spend most of it on a bigger model

After Chinchilla:

  • LLaMA-1 (65B on 1.4T tokens), LLaMA-2 (70B on 2T tokens): smaller, well-trained models
  • LLaMA-3-8B trained on 15T tokens — 1875 tokens per parameter (far beyond Chinchilla optimal)
  • Modern practice: train smaller models on much more data because inference cost matters

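The insight: a 7B model trained on 2T tokens and a 70B model trained on 200B tokens cost about the same to train (6 × N × D ≈ 8.4e22 FLOPs each), but the 7B model is roughly 10× cheaper per generated token at inference (about 2 × N FLOPs per token), and it may match the 70B model on many benchmarks.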
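The 7B-versus-70B comparison can be checked with the usual approximations: training costs about 6 × N × D FLOPs and inference about 2 × N FLOPs per generated token. A minimal sketch, assuming the Chinchilla fit's constants from earlier in this article; its loss predictions are for the Chinchilla training data, so treat them as indicative only.

```python
def chinchilla_loss(n: float, d: float) -> float:
    """Chinchilla parametric fit (Hoffmann et al., 2022): E + A/N^alpha + B/D^beta."""
    return 1.69 + 406.4 / n ** 0.34 + 410.7 / d ** 0.28


configs = {
    "7B on 2T tokens": (7e9, 2e12),
    "70B on 200B tokens": (70e9, 200e9),
}

results = {}
for name, (n, d) in configs.items():
    results[name] = {
        "train_flops": 6.0 * n * d,        # ≈ 6 FLOPs per parameter per training token
        "infer_flops_per_token": 2.0 * n,  # ≈ 2 FLOPs per parameter per generated token
        "predicted_loss": chinchilla_loss(n, d),
    }
    r = results[name]
    print(f"{name}: train {r['train_flops']:.1e} FLOPs, "
          f"inference {r['infer_flops_per_token']:.1e} FLOPs/token, "
          f"predicted loss {r['predicted_loss']:.3f}")
```

Both runs cost about 8.4e22 FLOPs to train; the 7B model needs 10× fewer FLOPs per generated token and, under this fit, even has a slightly lower predicted loss (about 2.02 versus about 2.05).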


Scaling Laws in Practice

Using smaller models to predict large-model loss

Python
import numpy as np
from scipy.optimize import curve_fit

def power_law(n, a, alpha):
    """Simple power law: loss = a × N^(-alpha)"""
    return a * n ** (-alpha)

# Fit on small model runs
model_sizes = np.array([125e6, 350e6, 760e6, 1.3e9, 2.7e9])
observed_losses = np.array([3.12, 2.85, 2.67, 2.54, 2.41])  # Illustrative validation losses

# With N in the 1e8-1e9 range, curve_fit's default start (a=1, alpha=1) is badly
# scaled and can converge poorly, so pass a rough initial guess via p0.
params, _ = curve_fit(power_law, model_sizes, observed_losses, p0=(10.0, 0.1))
a_fit, alpha_fit = params

# Predict 70B model loss (a ~26× extrapolation beyond the largest run)
predicted_70b_loss = power_law(70e9, a_fit, alpha_fit)
print(f"Fitted power law: loss = {a_fit:.2f} × N^(-{alpha_fit:.3f})")
print(f"Predicted 70B loss: {predicted_70b_loss:.3f}")
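The pure power law above has no floor, so it predicts improvement without limit as N grows. A common refinement, mirroring the E term in the Chinchilla fit, adds an irreducible-loss constant. A minimal sketch, reusing the same illustrative losses; with only five small runs the floor is poorly constrained, so the bounds and the initial guess `p0` are assumptions, and the fitted values are better read as a sensitivity check than a forecast.

```python
import numpy as np
from scipy.optimize import curve_fit


def saturating_power_law(n, e, a, alpha):
    """loss = e + a * N^(-alpha), where e is an irreducible-loss floor."""
    return e + a * n ** (-alpha)


# Same illustrative (not measured) losses as in the fit above.
model_sizes = np.array([125e6, 350e6, 760e6, 1.3e9, 2.7e9])
observed_losses = np.array([3.12, 2.85, 2.67, 2.54, 2.41])

# Bounds keep the floor between 0 and the best observed loss, and keep a, alpha >= 0
# so the fitted curve can only decrease as N grows.
params, _ = curve_fit(
    saturating_power_law,
    model_sizes,
    observed_losses,
    p0=(1.5, 50.0, 0.2),
    bounds=([0.0, 0.0, 0.0], [observed_losses.min(), np.inf, 2.0]),
    max_nfev=10_000,
)
e_fit, a_fit, alpha_fit = params

predicted_70b = saturating_power_law(70e9, *params)
print(f"Fitted: loss = {e_fit:.2f} + {a_fit:.1f} × N^(-{alpha_fit:.3f})")
print(f"Predicted 70B loss with a floor: {predicted_70b:.3f}")
```

Comparing this prediction with the pure power-law one shows how sensitive a 26× extrapolation is to the assumed functional form.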

Estimating compute requirements

Python
def estimate_training_compute(
    num_params: float,
    num_tokens: float,
    flops_per_param_per_token: float = 6.0,  # ~2 forward + ~4 backward (standard 6·N·D rule)
    peak_flops_per_gpu: float = 312e12,      # A100 dense BF16 peak: ~312 TFLOP/s
    mfu: float = 0.40,                       # Model FLOP utilization: typically 35-50%
) -> dict:
    """Estimate total training FLOPs and the equivalent single-A100 time."""
    total_flops = flops_per_param_per_token * num_params * num_tokens

    # Sustained throughput = peak × utilization
    effective_flops_per_sec = peak_flops_per_gpu * mfu

    seconds = total_flops / effective_flops_per_sec
    gpu_hours = seconds / 3600

    return {
        "total_flops": total_flops,
        "gpu_hours_single_a100": gpu_hours,
        "gpu_days_single_a100": gpu_hours / 24,
    }

# LLaMA-3-8B: 8B params, 15T tokens
compute = estimate_training_compute(8e9, 15e12)
print(f"Training FLOPs: {compute['total_flops']:.2e}")
print(f"Single A100 GPU-hours: {compute['gpu_hours_single_a100']:.0f}")
# Assumes perfect linear scaling across GPUs; real clusters lose some efficiency
print(f"With 1024 A100s: ~{compute['gpu_days_single_a100'] / 1024:.0f} days")
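The same arithmetic also runs in reverse: start from a hardware budget and ask which Chinchilla-optimal model it can train. A minimal sketch, assuming A100-class peak throughput, 40% MFU, and the 20-tokens-per-parameter rule; `optimal_model_for_cluster` is a hypothetical planning helper, not an established tool.

```python
import math


def optimal_model_for_cluster(
    num_gpus: int,
    days: float,
    peak_flops_per_gpu: float = 312e12,  # assumed: A100 dense BF16 peak
    mfu: float = 0.40,                   # assumed model FLOP utilization
    tokens_per_param: float = 20.0,      # Chinchilla rule of thumb
) -> dict:
    """Hypothetical planning helper: turn a GPU budget into a 20:1 model size."""
    compute = num_gpus * peak_flops_per_gpu * mfu * days * 86_400  # 86,400 seconds per day
    # C ≈ 6·N·D and D = k·N  =>  N = sqrt(C / (6k))
    params = math.sqrt(compute / (6.0 * tokens_per_param))
    return {
        "compute_flops": compute,
        "params": params,
        "tokens": tokens_per_param * params,
    }


plan = optimal_model_for_cluster(num_gpus=1024, days=30)
print(f"Compute: {plan['compute_flops']:.2e} FLOPs")
print(f"Model:   {plan['params'] / 1e9:.0f}B params on {plan['tokens'] / 1e12:.2f}T tokens")
```

Under these assumptions, a month on 1024 A100s (about 3.3e23 FLOPs) is enough for a roughly 50B-parameter model trained on about 1T tokens, close to GPT-3's budget.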

Beyond Loss: Task Performance Scaling

Scaling laws for perplexity/loss don't always translate directly to task performance. Some tasks show:

  • Smooth scaling: Translation, summarization, code completion — gradual improvement with scale
  • Emergent capabilities: Chain-of-thought reasoning, arithmetic — appear suddenly at a threshold scale
Emergent behavior: a capability that is absent (near chance) in small models and then appears abruptly as scale increases, with no gradual intermediate stage. How much of this abruptness is intrinsic, and how much comes from all-or-nothing metrics such as exact-match accuracy, is still debated.

This makes scaling-law extrapolation for specific capabilities less reliable than for overall loss. A model that reaches its predicted loss may still fail on specific reasoning tasks that only emerge at larger scale.
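One way a smooth underlying improvement can look abrupt on a task metric is an all-or-nothing score. A toy sketch, assuming per-token accuracy improves smoothly with scale, that an answer is `answer_length` tokens long, and that token errors are independent; the numbers are hypothetical, not measurements of any model.

```python
import numpy as np

# Hypothetical: per-token accuracy improves smoothly as models scale.
per_token_acc = np.linspace(0.80, 0.99, 6)
answer_length = 20  # assumed number of tokens that must all be correct

# Exact match needs every token right; with independent errors it scales as p ** k.
exact_match = per_token_acc ** answer_length

for p, em in zip(per_token_acc, exact_match):
    print(f"per-token accuracy {p:.3f} -> exact match {em:.3f}")
```

Per-token accuracy rises evenly from 0.80 to 0.99, while exact match stays near zero at first (about 0.01) and then climbs steeply to about 0.82, which on a log-scale model-size axis can read as a sudden jump.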


Key Takeaways for Practitioners

  1. Data efficiency: 20 tokens per parameter (Chinchilla) is the compute-optimal starting point. Modern practice trains beyond this for better inference-time efficiency.

  2. Fit curves before committing: run 5–10 small training runs and fit a power law before launching a large one. This is standard practice at ML research labs.

  3. Loss predicts loss, not tasks: Be cautious extrapolating scaling laws to task performance — emergent behaviors break smooth scaling.

  4. Inference cost matters: Chinchilla optimal for training is not optimal for deployment. If you will serve millions of requests, a smaller model trained on more tokens often wins on total cost.

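  5. Compute budget allocation: Given a fixed compute budget, model size and token count should grow at the same rate (both scale as C^0.5), so doubling compute means roughly 1.4× more parameters and 1.4× more tokens. This is not an equal split in absolute terms: at the optimum there are about 20 tokens per parameter.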
