Learnixo
AI Systems · Intermediate

LLaMA Architecture: Modern Decoder Design

How LLaMA and its derivatives (Mistral, Qwen, Phi) improve on the original transformer: RoPE, RMSNorm, SwiGLU, and grouped query attention (GQA).

Asma Hafeez Khan · May 16, 2026 · 6 min read
Transformers · LLaMA · Architecture · Open Source LLM

Why LLaMA Changed the Field

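Meta's LLaMA-1 (2023) showed that comparatively small open models trained on far more data can compete with much larger ones: LLaMA-13B outperformed GPT-3 (175B) on most benchmarks, and LLaMA-65B was competitive with Chinchilla-70B and PaLM-540B. Rather than stopping at Chinchilla's compute-optimal token budget, Meta trained well past it (1T tokens for the 7B and 13B models, 1.4T for 33B and 65B), spending extra training compute to get models that are cheaper to run. LLaMA also adopted several architectural improvements over the original transformer that are now standard across open-source LLMs.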

LLaMA-2 (2023) scaled up to 70B parameters trained on 2T tokens. LLaMA-3 (2024) released 8B and 70B models trained on over 15T tokens with a 128k-token vocabulary; a 405B model followed in LLaMA-3.1.


Core Architectural Changes from Original Transformer

| Feature | Original Transformer | LLaMA |
|---|---|---|
| Normalization | Post-norm LayerNorm | Pre-norm RMSNorm |
| Position encoding | Sinusoidal (learned in GPT) | RoPE |
| Attention | Multi-head (MHA) | MHA in LLaMA-1; grouped query (GQA) in LLaMA-2-70B and LLaMA-3 |
| FFN activation | ReLU | SwiGLU |
| FFN bias | Yes | No |
| Attention bias | Yes | No |
| Tokenizer | BPE (various) | SentencePiece BPE (LLaMA-1/2), tiktoken-based BPE (LLaMA-3) |
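RoPE (rotary position embedding) rotates each pair of query/key features by an angle proportional to the token's position, so the attention score between positions m and n depends only on the offset m − n, not on where the pair sits in the sequence. A minimal plain-Python sketch of a single feature pair; θ = 0.1 is an arbitrary illustrative frequency (real RoPE gives each pair i its own frequency base^(−2i/d), with base 10000 in LLaMA-1/2 and 500000 in LLaMA-3), and `rotate`/`rope_score` are hypothetical helper names:

```python
import math

def rotate(pair, pos, theta):
    """Rotate a 2-D feature pair by angle pos * theta (one RoPE frequency)."""
    x, y = pair
    angle = pos * theta
    return (x * math.cos(angle) - y * math.sin(angle),
            x * math.sin(angle) + y * math.cos(angle))

def rope_score(q, k, m, n, theta=0.1):
    """Dot product of a query at position m and a key at position n after rotation."""
    qm = rotate(q, m, theta)
    kn = rotate(k, n, theta)
    return qm[0] * kn[0] + qm[1] * kn[1]

q, k = (1.0, 2.0), (0.5, -1.0)
s1 = rope_score(q, k, m=10, n=7)     # offset m - n = 3
s2 = rope_score(q, k, m=103, n=100)  # offset 3 again, much later in the sequence
s3 = rope_score(q, k, m=10, n=9)     # offset 1

# Same offset gives the same score; a different offset changes it
print(math.isclose(s1, s2, abs_tol=1e-9), abs(s1 - s3) > 1e-3)  # True True
```

The reason is that rotating q by mθ and k by nθ leaves their dot product equal to q rotated against k by (n − m)θ: absolute positions cancel and only the relative offset survives.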


RMSNorm: Simpler Normalization

Original LayerNorm computes mean and variance across features:

$$\mathrm{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

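RMSNorm (Root Mean Square Layer Normalization) drops the mean centering and the bias β, and simply rescales x by its root mean square over the d features:

$$\mathrm{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}}$$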
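To make the difference concrete, here is a plain-Python sketch of both formulas on x = [1, 2, 3, 4], assuming γ = 1 and β = 0 (no learned parameters) for illustration. It also checks that for an input that is already zero-mean the two coincide, because mean(x²) then equals the variance:

```python
import math

def layer_norm(x, eps=1e-6):
    """LayerNorm with gamma = 1, beta = 0: subtract the mean, divide by the std."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)  # biased (population) variance
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def rms_norm(x, eps=1e-6):
    """RMSNorm with gamma = 1: no mean subtraction, divide by the root mean square."""
    ms = sum(v * v for v in x) / len(x)  # mean(x^2)
    return [v / math.sqrt(ms + eps) for v in x]

x = [1.0, 2.0, 3.0, 4.0]
print([round(v, 3) for v in layer_norm(x)])  # [-1.342, -0.447, 0.447, 1.342]
print([round(v, 3) for v in rms_norm(x)])    # [0.365, 0.73, 1.095, 1.461]

# For a zero-mean input, mean(x^2) equals the variance, so the two agree
centered = [v - sum(x) / len(x) for v in x]
same = all(math.isclose(a, b, abs_tol=1e-9)
           for a, b in zip(layer_norm(centered), rms_norm(centered)))
print(same)  # True
```

LayerNorm recenters the vector around zero, while RMSNorm only rescales it, so the outputs keep the sign pattern of the input.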
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain γ; no bias β needed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reciprocal RMS over the feature (last) dimension: 1 / sqrt(mean(x²) + eps)
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x * inv_rms

# Usage: applied before attention and FFN (pre-norm), not after
# LLaMA: norm → attention → residual → norm → FFN → residual
```

Pre-norm (applying norm before the sub-layer) instead of post-norm (original transformer) improves training stability at large scale.


SwiGLU Feed-Forward

LLaMA replaces the two-matrix ReLU/GELU feed-forward network with SwiGLU, a gated linear unit that uses SiLU (Swish) as its gate:
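Written out as a formula, with $W_1, W_3 \in \mathbb{R}^{h \times d}$ projecting up to the hidden width h, $W_2 \in \mathbb{R}^{d \times h}$ projecting back, and ⊙ the elementwise product (the names mirror `w1`, `w3`, `w2` in the code):

$$
\mathrm{SiLU}(z) = z\,\sigma(z) = \frac{z}{1 + e^{-z}}, \qquad
\mathrm{FFN}_{\mathrm{SwiGLU}}(x) = W_2\left(\mathrm{SiLU}(W_1 x) \odot W_3 x\right)
$$

The SiLU branch acts as a smooth, input-dependent gate that scales each hidden unit of the linear W_3 branch.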

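```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

class LLaMAFeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # SwiGLU uses 3 matrices instead of 2
        # hidden_dim is roughly 8/3 × dim; see compute_ffn_hidden_dim below
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down-projection back to dim
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value (up-projection)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: SiLU(W1·x) ⊙ W3·x, then project back
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

def compute_ffn_hidden_dim(dim: int, multiple_of: int = 256,
                           ffn_dim_multiplier: Optional[float] = None) -> int:
    """LLaMA's FFN hidden-dim rule: 8/3 × dim, optional multiplier, then round up."""
    hidden_dim = int(8 * dim / 3)
    if ffn_dim_multiplier is not None:
        # LLaMA-3 widens the FFN with ffn_dim_multiplier=1.3
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to a multiple of multiple_of for hardware efficiency
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    return hidden_dim

# LLaMA-1/2 7B: dim=4096 → 11008
print(compute_ffn_hidden_dim(4096))  # 11008
# LLaMA-3-8B: multiple_of=1024, multiplier 1.3 → 14336 (3.5× dim, not 4× = 16384)
print(compute_ffn_hidden_dim(4096, multiple_of=1024, ffn_dim_multiplier=1.3))  # 14336
```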
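Why 8/3? A gated FFN has three d × h matrices instead of two, so shrinking h from 4d to (8/3)d keeps the weight count unchanged: 3 · d · (8d/3) = 8d² = 2 · d · 4d. A quick plain-Python check using LLaMA-2-7B's rounded width of 11008 (`ffn_params` is a hypothetical helper):

```python
def ffn_params(dim: int, hidden: int, n_mats: int) -> int:
    """Weights in a bias-free FFN made of n_mats matrices of shape dim x hidden."""
    return n_mats * dim * hidden

d = 4096
standard = ffn_params(d, 4 * d, n_mats=2)  # classic FFN: W_in, W_out at 4x width
swiglu = ffn_params(d, 11008, n_mats=3)    # SwiGLU: W1, W2, W3 at ~8/3 x width

print(standard, swiglu, round(swiglu / standard, 3))  # 134217728 135266304 1.008
```

The small excess over a ratio of 1.0 comes only from rounding 10922 up to a multiple of 256.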

Grouped Query Attention (GQA)

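LLaMA-2-70B and all LLaMA-3 models use GQA: the query heads are split into groups, and each group shares a single key/value head. This shrinks the KV cache, which stores K and V for every past token during generation: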
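The grouping rule itself is simple. With consecutive grouping (what `repeat_interleave` produces in the code below), query head h reads K/V head h // n_rep. Setting n_kv_heads = n_heads recovers standard MHA, and n_kv_heads = 1 gives multi-query attention (MQA), so GQA interpolates between the two. A plain-Python sketch; `kv_head_for_query_head` is a hypothetical helper:

```python
def kv_head_for_query_head(q_head: int, n_heads: int, n_kv_heads: int) -> int:
    """Index of the K/V head that query head q_head uses (consecutive grouping)."""
    assert n_heads % n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"
    n_rep = n_heads // n_kv_heads  # query heads per K/V head
    return q_head // n_rep

# n_kv_heads == n_heads is plain MHA; n_kv_heads == 1 is multi-query attention
for n_kv in (32, 8, 1):
    mapping = [kv_head_for_query_head(h, n_heads=32, n_kv_heads=n_kv) for h in range(32)]
    print(f"n_kv_heads={n_kv:2d}: first 8 query heads read K/V heads {mapping[:8]}")
```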

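```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def precompute_freqs(head_dim: int, seq_len: int, theta: float = 10000.0):
    """RoPE cos/sin tables, each of shape (seq_len, head_dim // 2).
    LLaMA-1/2 use theta=10000; LLaMA-3 uses 500000."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # position × frequency
    return angles.cos(), angles.sin()

def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    """Rotate each adjacent feature pair of x (batch, heads, seq, head_dim) by its position's angle."""
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rot_even = x_even * freqs_cos - x_odd * freqs_sin
    rot_odd = x_even * freqs_sin + x_odd * freqs_cos
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)  # re-interleave the pairs

class GroupedQueryAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, n_kv_heads: int, head_dim: int):
        super().__init__()
        assert n_heads % n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.n_rep = n_heads // n_kv_heads  # How many Q heads share each KV head

        self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _ = x.shape

        # Project to Q, K, V with layout (batch, heads, seq, head_dim)
        xq = self.wq(x).view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        xk = self.wk(x).view(batch, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        xv = self.wv(x).view(batch, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to queries and keys (values are not rotated)
        xq = apply_rotary_emb(xq, freqs_cos, freqs_sin)
        xk = apply_rotary_emb(xk, freqs_cos, freqs_sin)

        # Repeat KV heads to match number of Q heads
        xk = xk.repeat_interleave(self.n_rep, dim=1)  # (batch, n_heads, seq, head_dim)
        xv = xv.repeat_interleave(self.n_rep, dim=1)

        # Scaled dot-product attention (causal)
        out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(out)

# Quick shape check with toy sizes and random weights
attn = GroupedQueryAttention(dim=256, n_heads=8, n_kv_heads=2, head_dim=32)
cos, sin = precompute_freqs(head_dim=32, seq_len=10)
print(attn(torch.randn(1, 10, 256), cos, sin).shape)  # torch.Size([1, 10, 256])

# LLaMA-3-8B: n_heads=32, n_kv_heads=8 → 4× less KV cache memory than MHA
```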
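GQA's KV-cache saving follows directly from the cache shape: each layer stores K and V, each of shape (batch, n_kv_heads, seq_len, head_dim). A back-of-the-envelope sketch for LLaMA-3-8B's shape (32 layers, head_dim = 128) at its 8,192-token context, assuming a 16-bit (bf16) cache and batch size 1; `kv_cache_bytes` is a hypothetical helper:

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int, seq_len: int,
                   batch: int = 1, bytes_per_elem: int = 2) -> int:
    """K and V caches: 2 tensors per layer, each (batch, n_kv_heads, seq_len, head_dim)."""
    return 2 * n_layers * batch * n_kv_heads * seq_len * head_dim * bytes_per_elem

GiB = 1024 ** 3
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=8192)   # LLaMA-3-8B
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=8192)  # same model with MHA

print(f"GQA: {gqa / GiB:.1f} GiB, MHA: {mha / GiB:.1f} GiB")  # GQA: 1.0 GiB, MHA: 4.0 GiB
```

Per token that is 128 KiB with GQA versus 512 KiB with MHA; the cache grows linearly with both batch size and sequence length.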

Complete LLaMA Layer

```python
# Builds on RMSNorm, GroupedQueryAttention and LLaMAFeedForward defined above
class LLaMABlock(nn.Module):
    def __init__(self, dim: int, n_heads: int, n_kv_heads: int, ffn_hidden: int):
        super().__init__()
        head_dim = dim // n_heads

        self.attention_norm = RMSNorm(dim)
        self.attention = GroupedQueryAttention(dim, n_heads, n_kv_heads, head_dim)
        self.ffn_norm = RMSNorm(dim)
        self.feed_forward = LLaMAFeedForward(dim, ffn_hidden)

    def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
        # Pre-norm attention with residual
        h = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin)
        # Pre-norm FFN with residual
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
```

The pre-norm pattern (norm → sublayer → residual add) is more stable than the original post-norm (sublayer → residual add → norm) for deep networks.
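One common explanation for this is the identity path. A toy PyTorch sketch: with a sublayer that outputs zeros (a crude stand-in for a freshly initialized, near-zero sublayer), a pre-norm step returns its input unchanged, while a post-norm step still rescales it, so in a deep post-norm stack every residual connection passes through a normalization. The helper names are hypothetical, and the norm is RMSNorm without the learned γ:

```python
import torch

def rms_norm(t: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm without the learnable gain (gamma = 1)
    return t * torch.rsqrt(t.pow(2).mean(dim=-1, keepdim=True) + eps)

def post_norm_step(x, sublayer, norm):
    # Original transformer: add the residual first, then normalize the sum
    return norm(x + sublayer(x))

def pre_norm_step(x, sublayer, norm):
    # LLaMA: normalize only the sublayer input; the residual stream bypasses the norm
    return x + sublayer(norm(x))

def zero_sublayer(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for a sublayer whose output is (near) zero
    return torch.zeros_like(t)

torch.manual_seed(0)
x = 3.0 * torch.randn(2, 5, 16)  # activations with RMS around 3

print(torch.equal(pre_norm_step(x, zero_sublayer, rms_norm), x))      # True: exact identity
print(torch.allclose(post_norm_step(x, zero_sublayer, rms_norm), x))  # False: input is rescaled
```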


LLaMA-3 Model Configurations

```python
# LLaMA-3-8B configuration
llama3_8b_config = {
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,       # GQA: 32 Q heads, 8 KV heads
    "vocab_size": 128256,  # vs 32000 in LLaMA-2
    "ffn_hidden": 14336,   # int(8/3 × 4096) = 10922, × 1.3 → 14198, rounded up to a multiple of 1024
    "rope_base": 500000,   # vs 10000 in LLaMA-1/2; a higher base suits longer contexts
    "max_seq_len": 8192,   # LLaMA-3.1 extends the context to 128k
    "norm_eps": 1e-5,
}

# Parameter count estimate
def count_params(config: dict) -> dict:
    d = config["dim"]
    v = config["vocab_size"]
    L = config["n_layers"]
    h = config["n_heads"]
    kv = config["n_kv_heads"]
    hd = d // h
    ffn = config["ffn_hidden"]

    embedding = v * d
    output = v * d  # LM head: LLaMA-3-8B does NOT tie it to the embedding
    attention_per_layer = (h * hd * d) + (kv * hd * d) * 2 + (h * hd * d)  # WQ + WK + WV + WO
    ffn_per_layer = d * ffn * 3  # W1, W2, W3 (SwiGLU)
    norm_per_layer = d * 2  # Two RMSNorms per layer

    total = embedding + output + L * (attention_per_layer + ffn_per_layer + norm_per_layer)
    total += d  # Final norm

    return {
        "embedding": embedding,
        "output": output,
        "attention_all_layers": L * attention_per_layer,
        "ffn_all_layers": L * ffn_per_layer,
        "total_billions": total / 1e9,
    }

counts = count_params(llama3_8b_config)
print(f"Total parameters: {counts['total_billions']:.1f}B")  # Total parameters: 8.0B
```
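For comparison, the same counting applied to LLaMA-2-7B's shape (MHA with 32 KV heads, a 32,000-token vocabulary, FFN width 11008) reproduces its familiar 6.7B figure. A compact plain-Python sketch; `llama_params` is a hypothetical helper that assumes untied input and output embeddings, as in both models:

```python
def llama_params(dim: int, n_layers: int, n_heads: int, n_kv_heads: int,
                 vocab: int, ffn_hidden: int) -> int:
    hd = dim // n_heads
    attn = 2 * dim * n_heads * hd + 2 * dim * n_kv_heads * hd  # WQ, WO + WK, WV
    ffn = 3 * dim * ffn_hidden                                 # SwiGLU W1, W2, W3
    per_layer = attn + ffn + 2 * dim                           # + two RMSNorm gains
    embed_and_head = 2 * vocab * dim                           # untied embedding + LM head
    return n_layers * per_layer + embed_and_head + dim         # + final RMSNorm

llama2_7b = llama_params(4096, 32, 32, 32, 32000, 11008)
llama3_8b = llama_params(4096, 32, 32, 8, 128256, 14336)
print(f"{llama2_7b / 1e9:.2f}B vs {llama3_8b / 1e9:.2f}B")  # 6.74B vs 8.03B
```

The extra ~1.3B comes from the wider FFN and the 4× larger vocabulary, partly offset by GQA's smaller K/V projections.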

Key Derivatives

Mistral-7B: Uses sliding window attention (within each layer, a token attends only to the previous 4,096 tokens) and GQA. It outperforms LLaMA-2-13B across the benchmarks in Mistral's paper, with about 7.3B parameters versus 13B.
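A sliding window is just a narrower causal mask. A minimal PyTorch sketch, assuming the convention that query i may see the current token and the window − 1 tokens before it (implementations differ by one on whether the bound is inclusive); the toy uses window = 3 instead of 4,096, and `sliding_window_causal_mask` is a hypothetical name. Because each layer looks back up to W positions, stacking k layers lets information travel roughly k × W tokens (32 × 4,096 ≈ 131k for Mistral-7B's 32 layers):

```python
import torch
import torch.nn.functional as F

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean (seq_len, seq_len) mask; True means query i may attend to key j."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions, shape (seq_len, 1)
    j = torch.arange(seq_len).unsqueeze(0)  # key positions, shape (1, seq_len)
    return (i >= j) & (window > i - j)      # causal AND within the window

mask = sliding_window_causal_mask(seq_len=6, window=3)
print(mask.int())

# PyTorch's fused attention accepts a boolean mask: True entries take part in attention
q = k = v = torch.randn(1, 4, 6, 16)  # (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 4, 6, 16])
```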

Mixtral-8×7B: Mistral's architecture with a Mixture-of-Experts FFN: each layer has 8 experts, and a router activates 2 of them per token. That gives 46.7B total parameters but only 12.9B active parameters per token.
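A minimal sketch of top-2 routing in PyTorch: a linear router scores all experts, the 2 highest-scoring experts are kept, their scores are softmax-normalized (over the two selected router logits, which is how the Mixtral paper describes it), and the token's output is the weighted sum of those experts' SwiGLU outputs. Toy sizes only; load balancing, capacity limits, and efficient expert-parallel kernels are omitted, and `TopKMoE`/`SwiGLUExpert` are hypothetical names:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUExpert(nn.Module):
    """One expert: the same SwiGLU FFN shape as in LLaMA/Mistral."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class TopKMoE(nn.Module):
    """Sparse MoE FFN: a linear router sends each token to its top_k experts."""
    def __init__(self, dim: int, hidden_dim: int, n_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(dim, n_experts, bias=False)
        self.experts = nn.ModuleList([SwiGLUExpert(dim, hidden_dim) for _ in range(n_experts)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = x.reshape(-1, x.shape[-1])                   # (n_tokens, dim)
        logits = self.router(tokens)                          # (n_tokens, n_experts)
        top_logits, top_idx = torch.topk(logits, self.top_k, dim=-1)
        weights = F.softmax(top_logits, dim=-1)               # normalize over chosen experts only
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            token_ids, slot = torch.where(top_idx == e)       # tokens that picked expert e
            if token_ids.numel() == 0:
                continue                                      # expert unused in this batch
            y = expert(tokens[token_ids]) * weights[token_ids, slot].unsqueeze(-1)
            out.index_add_(0, token_ids, y)                   # accumulate weighted outputs
        return out.reshape(x.shape)

torch.manual_seed(0)
moe = TopKMoE(dim=64, hidden_dim=128)
x = torch.randn(2, 5, 64)
out = moe(x)
print(out.shape)  # torch.Size([2, 5, 64])
```

Only 2 of the 8 expert FFNs run for each token, which is why the active parameter count is far below the total.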

Qwen-2: Alibaba's model family. Its architecture is similar to LLaMA-3's, with a larger vocabulary (about 152k tokens) and more Chinese-language training data.

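Phi-3: Microsoft's small models (3.8B mini, 7B small, 14B medium). Phi-3-mini reuses a LLaMA-2-style decoder block and tokenizer; the family's focus is high-quality, heavily filtered and synthetic training data rather than scale.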

The LLaMA architecture (RoPE + RMSNorm + SwiGLU + GQA + pre-norm) has become the de facto standard for open-source decoder-only transformers.
