Feed-Forward Networks in Transformers
Understand the position-wise feed-forward network (FFN) in transformer layers: its role, architecture, activation functions, and how it differs from attention.
What the FFN Does
A standard transformer layer has two sub-layers: multi-head attention and a position-wise feed-forward network (FFN). Attention handles token-to-token relationships. The FFN processes each token independently, applying the same weights at every position.
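Here is a minimal sketch of one layer that shows how the two sub-layers fit together. It assumes the common pre-norm arrangement with residual connections (the original paper normalized after each residual add instead). The class name TransformerBlock and all sizes are hypothetical, and the causal mask a decoder would need is left out for brevity.

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """Hypothetical pre-norm layer: attention sub-layer, then FFN sub-layer."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Sub-layer 1: attention mixes information across token positions.
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        # Sub-layer 2: the FFN transforms each position on its own.
        x = x + self.ffn(self.ln2(x))
        return x


block = TransformerBlock(d_model=64, n_heads=4, d_ff=256)
x = torch.randn(2, 10, 64)  # (batch, seq_len, d_model)
y = block(x)
print(y.shape)  # torch.Size([2, 10, 64])
```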
The FFN is widely regarded as the place where much of a transformer's learned knowledge is stored. Interpretability research suggests that factual recall, world knowledge, and many language patterns live primarily in FFN weights rather than in attention weights.
Architecture
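Input: $x \in \mathbb{R}^{\text{seq\_len} \times d_{\text{model}}}$

$$\mathrm{FFN}(x) = \mathrm{activation}(x W_1 + b_1)\, W_2 + b_2$$

$W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ (expansion: $d_{\text{model}} \to d_{\text{ff}}$)
$W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$ (projection: $d_{\text{ff}} \to d_{\text{model}}$)
$d_{\text{ff}}$ is typically $4 \times d_{\text{model}}$.

The 4× expansion ratio is the standard set by the original "Attention Is All You Need" paper. GPT-3 (175B) uses d_model = 12288 and d_ff = 49152.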
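You can trace the formula with raw tensors before wrapping it in a module. This sketch uses illustrative sizes (seq_len = 6, d_model = 16) and ReLU, the original paper's activation. The weights are random placeholders because only the shapes matter here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 6, 16
d_ff = 4 * d_model  # standard 4x expansion

x = torch.randn(seq_len, d_model)        # x in R^(seq_len x d_model)
W1 = torch.randn(d_model, d_ff) * 0.02   # expansion: d_model -> d_ff
b1 = torch.zeros(d_ff)
W2 = torch.randn(d_ff, d_model) * 0.02   # projection: d_ff -> d_model
b2 = torch.zeros(d_model)

hidden = F.relu(x @ W1 + b1)  # (seq_len, d_ff): the wide intermediate layer
out = hidden @ W2 + b2        # (seq_len, d_model): back to model width

print(hidden.shape, out.shape)  # torch.Size([6, 64]) torch.Size([6, 16])
```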
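```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    """Position-wise FFN: expand to d_ff, apply GELU, project back to d_model."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)  # expansion: d_model -> d_ff
        self.w2 = nn.Linear(d_ff, d_model)  # projection: d_ff -> d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model); nn.Linear acts on the last dimension,
        # so every position is transformed independently with the same weights.
        return self.w2(self.dropout(F.gelu(self.w1(x))))
```

Activation Functions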
ReLU (original paper): simple and fast, max(0, x). Drawback: "dead" neurons. If a neuron's pre-activation is negative for every input, its gradient is zero and its weights stop updating.
GELU (BERT, GPT-1 onward): Gaussian Error Linear Unit, $x \cdot \Phi(x)$, where $\Phi$ is the CDF of the standard normal distribution. It is smoother than ReLU and empirically better on language tasks.
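You can check both definitions numerically. This sketch relies on PyTorch, whose F.gelu defaults to the exact erf-based form. It rebuilds GELU from x · Φ(x), then shows that ReLU passes zero gradient for negative inputs. A neuron that stays negative on every input therefore stops learning.

```python
import math

import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, steps=7, requires_grad=True)

# GELU from its definition: x * Phi(x), with Phi the standard normal CDF.
phi = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
gelu_manual = x * phi
print(torch.allclose(gelu_manual, F.gelu(x), atol=1e-6))  # True

# ReLU's gradient is zero wherever the input is negative, so a unit whose
# pre-activation is always negative receives no weight updates.
F.relu(x).sum().backward()
print(x.grad)  # zero where x < 0, one where x > 0
```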
SwiGLU (LLaMA, PaLM): A gated variant that replaces the single expansion layer with two parallel projections:
```python
class SwiGLUFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate branch, passed through SiLU
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # linear branch, multiplied by the gate
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # projection back to d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu(W1 x) ⊙ (W3 x) elementwise, then project with W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```

Because SwiGLU uses three weight matrices instead of two, it needs a smaller d_ff (typically 8/3 × d_model instead of 4×) to keep the parameter count comparable. LLaMA uses this.
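The 8/3 factor comes from matching weight counts. A gated FFN has three d_model × d_ff matrices instead of two, so setting 3 · d_model · d_ff = 2 · d_model · (4 · d_model) gives d_ff = 8/3 · d_model. The sketch below checks this for d_model = 4096. The round-up to a multiple of 256 is assumed to mirror Meta's reference LLaMA code, and its result matches LLaMA-7B's d_ff of 11008. Both helper names are hypothetical.

```python
def ffn_params(d_model: int, d_ff: int, gated: bool) -> int:
    """Weight count of one bias-free FFN: 2 matrices, or 3 for a gated (SwiGLU) FFN."""
    n_matrices = 3 if gated else 2
    return n_matrices * d_model * d_ff


def swiglu_hidden_dim(d_model: int, multiple_of: int = 256) -> int:
    """2/3 of the usual 4 * d_model, rounded up to a multiple of `multiple_of`.

    The rounding rule is an assumption modeled on Meta's reference LLaMA code.
    """
    hidden = int(2 * (4 * d_model) / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


d_model = 4096
d_ff_swiglu = swiglu_hidden_dim(d_model)
print(d_ff_swiglu)  # 11008

dense = ffn_params(d_model, 4 * d_model, gated=False)  # standard 4x FFN
gated = ffn_params(d_model, d_ff_swiglu, gated=True)   # SwiGLU with reduced d_ff
print(dense, gated, round(gated / dense, 3))  # 134217728 135266304 1.008
```

The rounding costs under 1% extra parameters. It keeps d_ff at a hardware-friendly size.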
Position-Wise: Why It Matters
"Position-wise" means the same FFN is applied independently to each token position. The token at position 5 and the token at position 42 go through the same weights, and no information passes between them inside the FFN.
This is unlike attention, which explicitly computes relationships between all pairs of tokens. The FFN is purely a function of a single token's representation: it refines that representation using learned patterns.
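One way to see the difference is to perturb a single token and watch which output positions change. This sketch uses toy sizes, randomly initialized PyTorch modules, and no masking. The FFN's output moves only at the edited position, while nn.MultiheadAttention's output moves at every position.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, seq_len = 32, 8
ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))
attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)

x = torch.randn(1, seq_len, d_model)
x_edit = x.clone()
x_edit[0, 3] += 1.0  # perturb only the token at position 3

with torch.no_grad():
    ffn_before, ffn_after = ffn(x), ffn(x_edit)
    attn_before, _ = attn(x, x, x)
    attn_after, _ = attn(x_edit, x_edit, x_edit)

# Per-position change in output: max absolute difference over the feature dimension.
ffn_delta = (ffn_after - ffn_before).abs().amax(dim=-1)[0]
attn_delta = (attn_after - attn_before).abs().amax(dim=-1)[0]
print(ffn_delta)   # ~0 everywhere except index 3
print(attn_delta)  # > 0 at every position
```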
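```python
import torch

ffn = torch.nn.Sequential(  # toy FFN: d_model=16, d_ff=64
    torch.nn.Linear(16, 64), torch.nn.GELU(), torch.nn.Linear(64, 16)
)
x = torch.randn(2, 10, 16)  # (batch, seq_len, d_model)

# The FFN processes each position independently; conceptually this is:
output = torch.empty_like(x)
for pos in range(x.shape[1]):
    output[:, pos] = ffn(x[:, pos])  # no dependency on x[:, other_pos]

# In practice, one call handles all positions at once with a batched matrix multiply:
assert torch.allclose(output, ffn(x), atol=1e-5)
```

FFN as Associative Memory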
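Geva et al. (2021) showed that FFN layers behave like key-value memories:
- Each column of $W_1$ (a row of w1.weight in PyTorch, since nn.Linear stores weights as (out_features, in_features)) is a "key". A neuron's pre-activation is the dot product of the input with its key, so the neuron fires when the input matches that pattern.
- Each row of $W_2$ (a column of w2.weight) is the matching "value": the vector that gets added to the residual stream, scaled by the neuron's activation, when its key fires.
A specific neuron in the FFN expansion might fire strongly for tokens related to "capital cities". The corresponding value vector then adds a direction in embedding space that shifts the output toward city-related predictions.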
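The key-value reading is an exact rewrite of the FFN computation, not an approximation. This sketch uses illustrative sizes, no biases, and ReLU for simplicity. It computes one token's FFN output twice: first as two matrix multiplies, then as a sum over neurons in which each neuron's activation scales its value vector.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff = 16, 64
w1 = nn.Linear(d_model, d_ff, bias=False)  # w1.weight: (d_ff, d_model); row i is key k_i
w2 = nn.Linear(d_ff, d_model, bias=False)  # w2.weight: (d_model, d_ff); column i is value v_i

x = torch.randn(d_model)  # one token's representation

with torch.no_grad():
    # Matrix form: two linear maps with an activation in between.
    out_matrix = w2(F.relu(w1(x)))

    # Memory form: each neuron scores its key against x, then adds its
    # value vector scaled by that activated score.
    out_memory = torch.zeros(d_model)
    for i in range(d_ff):
        key_i = w1.weight[i]        # (d_model,)
        value_i = w2.weight[:, i]   # (d_model,)
        coeff = F.relu(key_i @ x)   # how strongly memory slot i fires
        out_memory += coeff * value_i

print(torch.allclose(out_matrix, out_memory, atol=1e-5))  # True
```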
This view is consistent with several observations:
- Factual-editing methods such as ROME target FFN weights.
- Fine-tuning on new facts is often hypothesized to act largely through FFN weights.
- A larger d_ff (more neurons, hence more key-value slots) generally means more capacity for factual recall.
Parameter Count
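```python
d_model = 768
d_ff = 3072  # 4× expansion
n_layers = 12  # GPT-2 small

# FFN weight parameters per layer (biases add only d_ff + d_model = 3,840)
w1_params = d_model * d_ff  # 768 × 3072 = 2,359,296
w2_params = d_ff * d_model  # 3072 × 768 = 2,359,296
ffn_params = w1_params + w2_params  # 4,718,592 ≈ 4.7M per layer

# Attention projections per layer: W_Q, W_K, W_V, W_O, each d_model × d_model
attn_params = 4 * d_model * d_model  # 2,359,296 ≈ 2.4M per layer

# GPT-2 small has ~124M parameters in total (reported as 117M in the paper)
total_ffn = ffn_params * n_layers  # 56,623,104 ≈ 56.6M

# Within each layer the FFN holds ~2/3 of the weights (8·d_model² vs 4·d_model²).
# Token and position embeddings (~39M here) sit outside the layers, so across
# the whole model the FFN share is closer to 45%.
print(ffn_params / (ffn_params + attn_params))  # 0.6666666666666666
```

Mixture of Experts FFN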
Mixture of Experts (MoE) scales the FFN by replacing it with N expert FFNs plus a learned router that selects a small subset of experts (typically 2) for each token:
```python
# Reuses torch, nn, F and the FeedForward class from the first snippet.
class MoEFFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.experts = nn.ModuleList(
            [FeedForward(d_model, d_ff) for _ in range(num_experts)]
        )
        self.router = nn.Linear(d_model, num_experts)  # one logit per expert
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        scores = self.router(x)  # (batch, seq_len, num_experts)
        weights, indices = torch.topk(scores, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)  # renormalize over the chosen experts

        output = torch.zeros_like(x)
        for k in range(self.top_k):
            expert_idx = indices[..., k]    # (batch, seq_len): k-th expert choice per token
            expert_w = weights[..., k:k+1]  # (batch, seq_len, 1): its routing weight
            for i, expert in enumerate(self.experts):
                mask = expert_idx == i  # tokens whose k-th choice is expert i
                if mask.any():
                    # Run expert i only on its tokens and add the weighted result back
                    output[mask] += expert_w[mask] * expert(x[mask])
        return output
```

MoE increases model capacity without a proportional increase in compute: with top-2 routing over 8 experts, each token runs through only 2 of the 8 expert FFNs. Mixtral-8x7B and, reportedly, GPT-4 use this architecture.
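The capacity-versus-compute trade-off reduces to simple arithmetic. This sketch assumes that every expert has the same shape as one dense two-matrix FFN. It ignores biases and the small router, and the sizes are illustrative. The helper name moe_ffn_params is hypothetical.

```python
def moe_ffn_params(d_model: int, d_ff: int, num_experts: int, top_k: int) -> dict[str, int]:
    """Weight counts for one MoE FFN layer (bias-free two-matrix experts, router ignored)."""
    dense = 2 * d_model * d_ff  # one ordinary FFN: W1 plus W2
    return {
        "dense": dense,
        "moe_total": num_experts * dense,  # parameters stored (capacity)
        "moe_active": top_k * dense,       # parameters each token actually uses (compute)
    }


# Illustrative sizes, not taken from any specific model.
stats = moe_ffn_params(d_model=4096, d_ff=16384, num_experts=8, top_k=2)
print(stats["moe_total"] / stats["dense"])   # 8.0: capacity relative to one dense FFN
print(stats["moe_active"] / stats["dense"])  # 2.0: per-token FFN compute relative to one dense FFN
```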