Learnixo
AI Systems · Intermediate

Feed-Forward Networks in Transformers

The role of the position-wise FFN in each transformer block, the expand-and-contract design, activation functions, SwiGLU, and why FFN parameters dominate model size.

Asma Hafeez Khan · May 16, 2026 · 4 min read
Transformers · Feed-Forward · FFN · Architecture · Interview

What the FFN Does

Every transformer block contains a position-wise feed-forward network applied identically and independently to each token:

FFN(x) = Activation(x · W₁ + b₁) · W₂ + b₂

Dimensions (original Transformer):
  d_model = 512
  d_ff    = 2048  (4× expansion)
  W₁: (d_model, d_ff) = (512, 2048)
  W₂: (d_ff, d_model) = (2048, 512)

The same W₁ and W₂ are applied at every position within a layer; each layer has its own W₁ and W₂ (they are not shared across layers).

"Position-wise" means: the FFN sees one token's vector at a time. It does not mix information across positions (that's attention's job).
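A minimal NumPy sketch of the original ReLU version, FFN(x) = max(0, x · W₁ + b₁) · W₂ + b₂, using toy sizes (d_model = 8, d_ff = 32, keeping the 4× ratio) and random rather than trained weights. It checks the position-wise property: pushing a whole (seq_len, d_model) sequence through in one call gives the same result as transforming each token vector on its own. The helper name `ffn` is our own.

```python
import numpy as np

def ffn(x, w1, b1, w2, b2):
    """Original-Transformer FFN: ReLU(x @ W1 + b1) @ W2 + b2.

    x may be one token (d_model,) or a sequence (seq_len, d_model);
    the matrix products act on the last axis, so each row is
    transformed independently with the same weights.
    """
    hidden = np.maximum(0.0, x @ w1 + b1)   # (..., d_ff): expand + ReLU
    return hidden @ w2 + b2                 # (..., d_model): contract

rng = np.random.default_rng(0)
d_model, d_ff, seq_len = 8, 32, 5           # toy sizes with a 4x expansion
w1 = rng.normal(size=(d_model, d_ff))
b1 = rng.normal(size=d_ff)
w2 = rng.normal(size=(d_ff, d_model))
b2 = rng.normal(size=d_model)

x = rng.normal(size=(seq_len, d_model))     # one vector per token
batched = ffn(x, w1, b1, w2, b2)            # whole sequence at once
per_token = np.stack([ffn(x[t], w1, b1, w2, b2) for t in range(seq_len)])

print(batched.shape)                        # (5, 8)
print(np.allclose(batched, per_token))      # True
```

Because the products act on the last axis, one set of weights serves every position without any loop over tokens.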


Expand-and-Contract Design

The two linear layers create an expand-then-contract pattern:

Input:       d_model  (e.g., 512)
       ↓ W₁ 
Hidden:      d_ff     (e.g., 2048 — 4× expansion)
       ↓ Activation
       ↓ W₂
Output:      d_model  (e.g., 512)

The expanded hidden layer acts like a large dictionary of d_ff patterns. W₁ maps the input to this pattern space; the activation selects which patterns are active; W₂ reads out the result.
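The dictionary reading can be checked numerically. In this toy sketch (random weights, ReLU, illustrative sizes), each column of W₁ acts as one pattern detector and each row of W₂ as the vector that hidden unit writes out. The output equals the sum of W₂ rows weighted by the hidden activations, and units the ReLU switched off contribute nothing.

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, d_ff = 8, 32                    # toy sizes
w1 = rng.normal(size=(d_model, d_ff))    # column i: pattern detected by hidden unit i
w2 = rng.normal(size=(d_ff, d_model))    # row i: vector written out when unit i fires
x = rng.normal(size=d_model)             # a single token's vector

h = np.maximum(0.0, x @ w1)              # ReLU decides which patterns are active
active = np.flatnonzero(h)               # indices of units with h_i > 0

# Matrix form vs. explicit "dictionary lookup" over the active units only
y_matrix = h @ w2
y_sum = sum(h[i] * w2[i] for i in active)

print(f"{len(active)} of {d_ff} hidden units active")
print(np.allclose(y_matrix, y_sum))
```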


Activation Functions

Original Transformer: ReLU

ReLU(x) = max(0, x)   — simple, sparse (zero for negative inputs)

Modern models: GELU, SwiGLU, GeGLU

GELU(x) = x · Φ(x)   (Φ = standard normal CDF)
         ≈ x · σ(1.702x)  (fast approximation; σ = logistic sigmoid)
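A small pure-Python comparison of the activations above, written directly from the formulas rather than from any framework's implementation. `gelu` uses the exact form via `math.erf`, and `gelu_sigmoid_approx` is the x · σ(1.702x) approximation. Unlike ReLU, GELU returns small negative values for negative inputs instead of a hard zero.

```python
import math

def relu(x: float) -> float:
    return max(0.0, x)

def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_sigmoid_approx(x: float) -> float:
    # Approximation: x * sigmoid(1.702 * x) = x / (1 + exp(-1.702 * x))
    return x / (1.0 + math.exp(-1.702 * x))

for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={x:+.1f}  relu={relu(x):.4f}  "
          f"gelu={gelu(x):.4f}  approx={gelu_sigmoid_approx(x):.4f}")
```

Over [-5, 5] the sigmoid approximation stays within a few hundredths of the exact GELU.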

SwiGLU (LLaMA, Mistral, PaLM):
  FFN(x) = ((x · W₁) ⊙ SiLU(x · Wgate)) · W₂

  where SiLU(x) = x · σ(x)
  and ⊙ is element-wise multiplication (gating)

  SwiGLU uses THREE matrices (W₁, Wgate, W₂) instead of two
  d_ff is typically set to about 8/3 × d_model (2/3 of the usual 4×) so the
  total parameter count matches a standard two-matrix FFN with d_ff = 4 × d_model
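The 8/3 rule is simple arithmetic: three d_model × d_ff matrices with d_ff = (8/3) · d_model hold 8 · d_model² weights, the same as two matrices with d_ff = 4 · d_model. In practice d_ff is rounded to a hardware-friendly size. The sketch below assumes the rounding used in Meta's reference LLaMA code (take 2/3 of 4 · d_model, round up to a multiple of 256), which gives d_ff = 11008 for d_model = 4096; the helper names are hypothetical.

```python
def ffn_params(d_model: int, d_ff: int) -> int:
    return 2 * d_model * d_ff      # W1 and W2 (no biases)

def swiglu_params(d_model: int, d_ff: int) -> int:
    return 3 * d_model * d_ff      # W1, Wgate and W2 (no biases)

def llama_style_d_ff(d_model: int, multiple_of: int = 256) -> int:
    # Hypothetical helper: 2/3 of the usual 4x width, rounded UP to a multiple of `multiple_of`
    d_ff = int(2 * (4 * d_model) / 3)
    return multiple_of * ((d_ff + multiple_of - 1) // multiple_of)

# Exact 8/3 ratio (d_model divisible by 3): parameter counts match exactly
print(swiglu_params(768, 2048) == ffn_params(768, 4 * 768))   # True

d_model = 4096
d_ff_swiglu = llama_style_d_ff(d_model)
print(d_ff_swiglu)                                  # 11008
print(ffn_params(d_model, 4 * d_model))             # 134217728
print(swiglu_params(d_model, d_ff_swiglu))          # 135266304
```

With rounding, the SwiGLU block ends up marginally larger (under 1% here) than the 4× FFN it replaces.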

Code: FFN Variants

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

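class FFN(nn.Module):
    """Two-matrix FFN: W2(GELU(W1 x)), without biases as in LLaMA-style models."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # expand: d_model -> d_ff
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # contract: d_ff -> d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., d_model). nn.Linear acts on the last dimension,
        # so every token position is transformed independently.
        return self.w2(F.gelu(self.w1(x)))

class SwiGLUFFN(nn.Module):
    """Gated FFN: W2((W1 x) * SiLU(Wgate x)), using three matrices instead of two."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1    = nn.Linear(d_model, d_ff, bias=False)  # "up" projection
        self.wgate = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w2    = nn.Linear(d_ff, d_model, bias=False)  # "down" projection back to d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.wgate(x))       # SiLU(z) = z * sigmoid(z)
        return self.w2(self.w1(x) * gate)  # element-wise gating, then contract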

Why FFN Parameters Dominate

For a transformer block with d_model=4096, d_ff=16384:

Attention (Q, K, V, O):   4 × (4096 × 4096) = 67M params
FFN (W₁, W₂):             (4096 × 16384) + (16384 × 4096) = 134M params

FFN has 2× the parameters of attention per block.

For a 96-layer model:
  Attention, all 96 layers:  67M × 96 ≈ 6.4B
  FFN, all 96 layers:       134M × 96 ≈ 12.9B
  FFN ≈ 67% of the attention + FFN parameters (embeddings and norms excluded)

This is one reason FFN layers are a prime target for quantisation and pruning: they hold most of the model's weights.
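The arithmetic above can be reproduced in a few lines. This sketch assumes plain multi-head attention with four d_model × d_model projections and a non-gated two-matrix FFN, all without biases; grouped-query attention or SwiGLU would change the counts.

```python
def block_params(d_model: int, d_ff: int) -> tuple[int, int]:
    attention = 4 * d_model * d_model   # W_Q, W_K, W_V, W_O
    ffn = 2 * d_model * d_ff            # W1, W2
    return attention, ffn

attn, ffn = block_params(d_model=4096, d_ff=16384)
n_layers = 96

print(f"attention per layer: {attn:,}")                        # 67,108,864
print(f"FFN per layer:       {ffn:,}")                         # 134,217,728
print(f"attention, all layers: {attn * n_layers / 1e9:.1f}B")  # 6.4B
print(f"FFN, all layers:       {ffn * n_layers / 1e9:.1f}B")   # 12.9B
print(f"FFN share: {ffn / (attn + ffn):.1%}")                  # 66.7%
```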


FFN as a Key-Value Memory

Geva et al. (2021) found evidence that FFN layers behave like key-value memories. With the x · W₁ convention used above:

Columns of W₁ are "keys": patterns whose hidden unit activates when the input matches them
Rows of W₂ are "values": the vectors added to the output when the corresponding key fires

"The capital of France is ___"
  → certain FFN keys fire (matching this pattern)
  → corresponding values contribute "Paris" information
  → the model "retrieves" factual knowledge from FFN weights

This is consistent with the broader finding that much of an LLM's factual knowledge is stored in FFN weights rather than in attention.
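A hand-built toy makes the key-value reading concrete. With the x · W₁ convention used in this article, each column of W₁ is a key and each row of W₂ is its value. The three slots, the 4-dimensional space and the "Paris" direction below are illustrative assumptions, not weights from any real model: an input that mostly matches key 0 fires unit 0 strongly, so the output is dominated by value 0.

```python
import numpy as np

# Hypothetical toy memory: 3 slots (d_ff = 3) in a 4-dim space, hand-picked, not learned.
keys = np.array([[1.0, 0.0, 0.0, 0.0],     # slot 0: stands in for "capital of France"
                 [0.0, 1.0, 0.0, 0.0],     # slot 1: some other pattern
                 [0.0, 0.0, 1.0, 0.0]])    # slot 2: some other pattern
values = np.array([[0.0, 0.0, 0.0, 1.0],   # slot 0 writes the toy "Paris" direction
                   [0.0, 0.0, 1.0, 0.0],
                   [0.0, 1.0, 0.0, 0.0]])

w1 = keys.T    # (d_model, d_ff): each COLUMN of W1 is a key
w2 = values    # (d_ff, d_model): each ROW of W2 is a value

x = np.array([0.9, 0.1, -0.3, 0.0])        # input that mostly matches key 0
h = np.maximum(0.0, x @ w1)                # ReLU(key match scores): slot 0 fires strongly
y = h @ w2                                 # values summed, weighted by how strongly each key fired

print("key activations:", h)
print("output:", y)
```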


Attention vs FFN: Complementary Roles

Attention:
  Operates across positions
  Mixes information from different tokens
  "Communication" between positions

FFN:
  Operates within positions
  Transforms each token's representation independently
  "Computation" within each position
  Stores factual/associative knowledge

Together, they form the two fundamental operations in each transformer block.
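The communication/computation split can be tested directly: change one token's input and see which output positions move. This NumPy sketch uses a single-head, unmasked attention layer and a ReLU FFN with random, scaled-down weights (all toy assumptions). Perturbing token 0 changes the attention output at every position, but the FFN output only at position 0.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=-1, keepdims=True)   # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(x, wq, wk, wv):
    # Single-head, unmasked scaled dot-product attention: each output row
    # is a weighted average of the value vectors of ALL positions.
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    return softmax(scores) @ v

def ffn(x, w1, w2):
    # Position-wise FFN: each row of x is transformed on its own.
    return np.maximum(0.0, x @ w1) @ w2

rng = np.random.default_rng(2)
seq_len, d_model, d_ff = 4, 8, 32
wq, wk, wv = (rng.normal(scale=0.3, size=(d_model, d_model)) for _ in range(3))
w1 = rng.normal(size=(d_model, d_ff))
w2 = rng.normal(size=(d_ff, d_model))

x = rng.normal(size=(seq_len, d_model))
x_changed = x.copy()
x_changed[0] += 1.0                         # perturb token 0 only

# Largest absolute change at each output position
attn_diff = np.abs(self_attention(x_changed, wq, wk, wv)
                   - self_attention(x, wq, wk, wv)).max(axis=-1)
ffn_diff = np.abs(ffn(x_changed, w1, w2) - ffn(x, w1, w2)).max(axis=-1)

print("attention changed at:", np.flatnonzero(attn_diff > 1e-9))  # every position
print("FFN changed at:      ", np.flatnonzero(ffn_diff > 1e-9))   # position 0 only
```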


Interview Answer

"The position-wise FFN in each transformer block applies the same two-layer MLP independently to each token: FFN(x) = Activation(x·W₁)·W₂. The hidden dimension is typically 4× the model dimension. Modern models such as LLaMA and Mistral replace ReLU with SwiGLU, a gated activation that uses three matrices instead of two, with d_ff shrunk to about 8/3 × d_model to keep the parameter count the same. In a standard 4× design the FFN holds roughly 2/3 of the non-embedding parameters. Research suggests FFNs act as key-value memories that store factual knowledge, while attention mixes information across positions; the two are complementary."
