LLaMA Architecture: Modern Decoder Design
How LLaMA and its derivatives (Mistral, Qwen, Phi) improve on the original transformer: RoPE, RMSNorm, SwiGLU, pre-norm, and grouped query attention (GQA).
Why LLaMA Changed the Field
Meta's LLaMA-1 (2023) demonstrated that smaller models trained on more tokens can match or exceed much larger models trained on less data: the 7B model was trained on 1T tokens, and LLaMA-13B outperformed the 175B GPT-3 on most benchmarks. This built on Chinchilla's finding that earlier large models were undertrained for their size. LLaMA-1 also used several architectural improvements over the original transformer that are now standard across open-source LLMs.
LLaMA-2 (2023) extended to 70B with 2T tokens. LLaMA-3 (2024) used 8B/70B/405B scales (the 405B model shipped with LLaMA-3.1) with 15T tokens and a 128k vocabulary.
Core Architectural Changes from Original Transformer
| Feature | Original Transformer | LLaMA |
|---|---|---|
| Normalization | Post-layer LayerNorm | Pre-layer RMSNorm |
| Position encoding | Sinusoidal (learned in GPT) | RoPE |
| Attention | Multi-head (MHA) | Grouped query (GQA) in LLaMA-2 70B and LLaMA-3; MHA in earlier models |
| FFN activation | ReLU | SwiGLU |
| FFN bias | Yes | No bias |
| Attention bias | Yes | No bias |
| Tokenizer | BPE (various) | SentencePiece BPE (LLaMA-1/2), tiktoken (LLaMA-3) |
RMSNorm: Simpler Normalization
Original LayerNorm computes mean and variance across features:
    LayerNorm(x) = γ × (x - μ) / √(σ² + ε) + β

RMSNorm (Root Mean Square Layer Normalization) drops the mean centering and the bias β, and normalizes by the RMS alone:

    RMSNorm(x) = γ × x / √(mean(x²) + ε)

    import torch
    import torch.nn as nn

    class RMSNorm(nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))  # γ, no β needed

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Reciprocal RMS across the feature dimension: 1 / √(mean(x²) + ε)
            inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            return self.weight * x * inv_rms

    # Usage: applied before attention and FFN (pre-norm), not after
    # LLaMA: norm → attention → residual → norm → FFN → residual

Pre-norm (applying the norm before the sub-layer) instead of post-norm (as in the original transformer) improves training stability at large scale.
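To see what dropping the mean centering changes in practice, the following sketch compares a functional RMSNorm against `nn.LayerNorm` with its affine parameters disabled. The helper name `rms_norm` and the tensor shapes are illustrative assumptions, not part of any LLaMA codebase. On zero-mean inputs the two should agree; once the input is shifted, they should diverge.

```python
import torch
import torch.nn as nn

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Hypothetical helper: RMSNorm without the learned scale γ
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

torch.manual_seed(0)
x = torch.randn(2, 5, 16)

# LayerNorm without γ/β so only the normalization itself is compared
layer_norm = nn.LayerNorm(16, eps=1e-6, elementwise_affine=False)

# Case 1: features already have zero mean, so centering is a no-op
x_centered = x - x.mean(dim=-1, keepdim=True)
same_when_centered = torch.allclose(
    rms_norm(x_centered), layer_norm(x_centered), atol=1e-5
)

# Case 2: a constant shift is removed by LayerNorm but kept by RMSNorm
shifted = x + 3.0
differs_when_shifted = not torch.allclose(
    rms_norm(shifted), layer_norm(shifted), atol=1e-3
)

print(same_when_centered, differs_when_shifted)
```

The practical upshot is that RMSNorm saves the mean computation and the β parameter, at the cost of no longer being invariant to a constant offset in the features.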
SwiGLU Feed-Forward
LLaMA uses a SwiGLU feed-forward network instead of the ReLU FFN of the original transformer (or the GELU FFN of GPT-style models):

    class LLaMAFeedForward(nn.Module):
        def __init__(self, dim: int, hidden_dim: int):
            super().__init__()
            # SwiGLU uses 3 matrices instead of 2
            # hidden_dim is typically 8/3 × dim (rounded up to a multiple of 256 in LLaMA-1/2)
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # projection
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # SwiGLU: SiLU(W1·x) ⊙ W3·x, then project back
            return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
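    def compute_ffn_hidden_dim(dim: int, multiple_of: int = 256, ffn_dim_multiplier: float = 1.0) -> int:
        """LLaMA's FFN hidden dim calculation."""
        hidden_dim = int(8 * dim / 3)
        # Optional per-model scaling (LLaMA-3-8B uses 1.3)
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round up to multiple_of for hardware efficiency
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        return hidden_dim

    # LLaMA-2-7B: dim=4096 → hidden_dim=11008 (8/3 × dim, rounded up to a multiple of 256)
    print(compute_ffn_hidden_dim(4096))  # 11008
    # LLaMA-3-8B: dim=4096 → hidden_dim=14336 (not 4×=16384, but 3.5×)
    print(compute_ffn_hidden_dim(4096, multiple_of=1024, ffn_dim_multiplier=1.3))  # 14336

Grouped Query Attention (GQA)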
LLaMA-2 (70B) and LLaMA-3 use GQA to reduce KV cache memory:
    import torch.nn.functional as F

    class GroupedQueryAttention(nn.Module):
        def __init__(self, dim: int, n_heads: int, n_kv_heads: int, head_dim: int):
            super().__init__()
            assert n_heads % n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"
            self.n_heads = n_heads
            self.n_kv_heads = n_kv_heads
            self.head_dim = head_dim
            self.n_rep = n_heads // n_kv_heads  # How many Q heads per KV head
            self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
            self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
            self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
            self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)

        def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
            batch, seq_len, _ = x.shape
            # Project to Q, K, V and reshape to (batch, heads, seq, head_dim)
            xq = self.wq(x).view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
            xk = self.wk(x).view(batch, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
            xv = self.wv(x).view(batch, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
            # Apply RoPE to queries and keys.
            # apply_rotary_emb is assumed to be a RoPE helper that rotates each
            # feature pair by a position-dependent angle.
            xq = apply_rotary_emb(xq, freqs_cos, freqs_sin)
            xk = apply_rotary_emb(xk, freqs_cos, freqs_sin)
            # Repeat KV heads to match number of Q heads
            xk = xk.repeat_interleave(self.n_rep, dim=1)  # (batch, n_heads, seq, head_dim)
            xv = xv.repeat_interleave(self.n_rep, dim=1)
            # Scaled dot-product attention (causal)
            out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
            out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
            return self.wo(out)

    # LLaMA-3-8B: n_heads=32, n_kv_heads=8 → 4× less KV cache memory vs MHA

Complete LLaMA Layer
    class LLaMABlock(nn.Module):
        def __init__(self, dim: int, n_heads: int, n_kv_heads: int, ffn_hidden: int):
            super().__init__()
            head_dim = dim // n_heads
            self.attention_norm = RMSNorm(dim)
            self.attention = GroupedQueryAttention(dim, n_heads, n_kv_heads, head_dim)
            self.ffn_norm = RMSNorm(dim)
            self.feed_forward = LLaMAFeedForward(dim, ffn_hidden)

        def forward(self, x: torch.Tensor, freqs_cos, freqs_sin) -> torch.Tensor:
            # Pre-norm attention with residual
            h = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin)
            # Pre-norm FFN with residual
            out = h + self.feed_forward(self.ffn_norm(h))
            return out

The pre-norm pattern (norm → sublayer → residual add) is more stable than the original post-norm (sublayer → residual add → norm) for deep networks.
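The attention code calls `apply_rotary_emb` with `freqs_cos` and `freqs_sin` but does not show them. The following is a minimal sketch of one way to write them, assuming the interleaved (even, odd) pairing convention and inputs shaped `(batch, n_heads, seq_len, head_dim)`. The helper name `precompute_rope` is hypothetical, and other implementations pair the two halves of the feature vector instead.

```python
import torch

def precompute_rope(head_dim: int, seq_len: int, base: float = 10000.0):
    # One rotation frequency per feature pair: base^(-2i / head_dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # angle[pos, i] = pos * inv_freq[i], shape (seq_len, head_dim // 2)
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
    return angles.cos(), angles.sin()

def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, n_heads, seq_len, head_dim); rotate each (even, odd) pair
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos = freqs_cos[None, None, :, :]
    sin = freqs_sin[None, None, :, :]
    out_even = x_even * cos - x_odd * sin
    out_odd = x_even * sin + x_odd * cos
    # Re-interleave the pairs back into the last dimension
    return torch.stack((out_even, out_odd), dim=-1).flatten(-2)

# Same query/key vector placed at every position, so any difference in the
# attention score comes from position alone.
torch.manual_seed(0)
cos, sin = precompute_rope(head_dim=8, seq_len=6)
q = torch.randn(8).expand(1, 1, 6, 8)
k = torch.randn(8).expand(1, 1, 6, 8)
q_rot = apply_rotary_emb(q, cos, sin)
k_rot = apply_rotary_emb(k, cos, sin)
scores = q_rot @ k_rot.transpose(-2, -1)  # (1, 1, 6, 6)
```

Because each pair is rotated by an angle proportional to its position, the score between positions 2 and 1 should match the score between positions 4 and 3: only the relative offset matters. The `base` argument corresponds to `rope_base` in a model configuration.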
LLaMA-3 Model Configurations
    # LLaMA-3-8B configuration
    llama3_8b_config = {
        "dim": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 8,       # GQA: 32 Q heads, 8 KV heads
        "vocab_size": 128256,  # Larger vocabulary than LLaMA-2
        "ffn_hidden": 14336,   # 3.5 × dim: 8/3 × 4096 scaled by 1.3, rounded up to a multiple of 1024
        "rope_base": 500000,   # Higher base for longer context
        "max_seq_len": 8192,   # 8k at release; 128k with extended context fine-tuning
        "norm_eps": 1e-5,
    }

    # Parameter count estimate
    def count_params(config: dict) -> dict:
        d = config["dim"]
        v = config["vocab_size"]
        L = config["n_layers"]
        h = config["n_heads"]
        kv = config["n_kv_heads"]
        hd = d // h
        ffn = config["ffn_hidden"]
        embedding = v * d
        # LLaMA-3-8B does not tie the output projection to the embedding,
        # so the output matrix is counted separately.
        output_projection = v * d
        attention_per_layer = (h * hd * d) + (kv * hd * d) * 2 + (h * hd * d)  # WQ + WK + WV + WO
        ffn_per_layer = d * ffn * 3  # W1, W2, W3 (SwiGLU)
        norm_per_layer = d * 2       # Two RMSNorms per layer
        total = embedding + output_projection + L * (attention_per_layer + ffn_per_layer + norm_per_layer)
        total += d  # Final norm
        return {
            "embedding": embedding,
            "output_projection": output_projection,
            "attention_all_layers": L * attention_per_layer,
            "ffn_all_layers": L * ffn_per_layer,
            "total_billions": total / 1e9,
        }

    counts = count_params(llama3_8b_config)
    print(f"Total parameters: {counts['total_billions']:.1f}B")  # Total parameters: 8.0B

Key Derivatives
Mistral-7B: Uses sliding window attention (in each layer, a token attends only to a window of about 4096 preceding tokens) and GQA. Competitive with LLaMA-2-13B at roughly half the parameters.
Mixtral-8×7B: Mistral architecture with a Mixture of Experts FFN: 8 experts, 2 active per token. 46.7B total parameters, 12.9B active parameters per token.
Qwen-2: Alibaba's model family. Similar architecture to LLaMA-3, with a larger vocabulary (152k) and additional training data in Chinese.
Phi-3: Microsoft's small models (3.8B, 7B). Same decoder architecture, with a focus on high-quality training data rather than scale.
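Of these variations, sliding window attention is the one that changes the attention pattern itself. The sketch below builds a boolean mask for it and passes the mask to PyTorch's `scaled_dot_product_attention`. The function name `sliding_window_causal_mask` is hypothetical, and the window is assumed to include the current token; it illustrates the idea and is not Mistral's implementation.

```python
import torch
import torch.nn.functional as F

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # mask[i, j] is True when query position i may attend to key position j
    i = torch.arange(seq_len)[:, None]
    j = torch.arange(seq_len)[None, :]
    causal = j <= i              # never look at future tokens
    in_window = j > i - window   # only the last `window` positions (assumed to include i)
    return causal & in_window

mask = sliding_window_causal_mask(seq_len=6, window=3)

# Boolean attn_mask: True means "take part in attention"
torch.manual_seed(0)
q = torch.randn(1, 2, 6, 4)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 6, 4)
v = torch.randn(1, 2, 6, 4)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Each row of `mask` has at most `window` true entries, so the keys and values a layer needs per query are bounded by the window size rather than by the full sequence length.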
The LLaMA architecture (RoPE + RMSNorm + SwiGLU + GQA + pre-norm) has become the de facto standard for open-source decoder-only transformers.
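The GQA part of that recipe is easiest to appreciate as arithmetic. This back-of-the-envelope sketch estimates KV cache size for the LLaMA-3-8B shape given earlier (32 layers, head dimension 4096 / 32 = 128, 8192-token context), assuming 2-byte (fp16/bf16) cache entries and a batch of one. The function name `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int = 1, bytes_per_value: int = 2) -> int:
    # Factor of 2: one K tensor and one V tensor are cached per layer
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_value

# GQA with 8 KV heads vs. a hypothetical MHA variant with 32 KV heads
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=8192)
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=8192)

print(f"GQA: {gqa / 2**30:.0f} GiB, MHA: {mha / 2**30:.0f} GiB, ratio: {mha // gqa}x")
# GQA: 1 GiB, MHA: 4 GiB, ratio: 4x
```

Under these assumptions the cache shrinks from 4 GiB to 1 GiB per 8k-token sequence, which is the 4× saving that follows from having 8 KV heads instead of 32.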