Learnixo
AI Systems · Intermediate

Feed-Forward Networks in Transformers

Understand the position-wise feed-forward network (FFN) in transformer layers: its role, architecture, activation functions, and how it differs from attention.

Asma Hafeez Khan · May 16, 2026 · 4 min read
Transformers · FFN · Architecture · Deep Learning

What the FFN Does

A standard transformer layer has two sub-layers: multi-head attention and a position-wise feed-forward network (FFN). Attention handles token-to-token relationships. The FFN processes each token independently, applying the same weights at every position.

Interpretability research suggests that the FFN holds much of a transformer's learned knowledge: factual recall, world knowledge, and many language patterns appear to live primarily in FFN weights rather than attention weights.


Architecture

Input: x ∈ R^(seq_len × d_model)

FFN(x) = activation(x · W₁ + b₁) · W₂ + b₂

W₁: d_model → d_ff      (expansion)
W₂: d_ff → d_model      (projection)
d_ff is typically 4 × d_model

The 4× expansion ratio comes from the original "Attention Is All You Need" paper (d_model = 512, d_ff = 2048). GPT-3 (175B) keeps the same ratio: d_model = 12288, d_ff = 49152.
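To make the formula concrete without any framework, here is a minimal dependency-free sketch for a single token. The tiny weights are hypothetical, chosen so the arithmetic can be checked by hand, and ReLU stands in for the activation.

```python
# Minimal sketch of FFN(x) = activation(x · W1 + b1) · W2 + b2 for ONE token.
# Hypothetical tiny weights; ReLU is the activation; plain lists, no libraries.

def vec_mat(x, w):
    """Row vector x (length n) times matrix w (n rows x m columns) -> length m."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def add(a, b):
    return [p + q for p, q in zip(a, b)]

def relu(v):
    return [max(0.0, a) for a in v]

def ffn(x, w1, b1, w2, b2):
    hidden = relu(add(vec_mat(x, w1), b1))  # expansion: d_model -> d_ff
    return add(vec_mat(hidden, w2), b2)     # projection: d_ff -> d_model

# d_model = 2, d_ff = 8 (the usual 4x expansion)
W1 = [[1, -1, 0, 2, 0, 1, -2, 0],
      [0, 1, -1, 0, 1, 1, 0, -1]]   # shape (d_model, d_ff) = (2, 8)
b1 = [0.0] * 8
W2 = [[1, 0], [0, 1]] * 4           # shape (d_ff, d_model) = (8, 2): even units feed dim 0, odd units dim 1
b2 = [0.5, -0.5]

y = ffn([1.0, 2.0], W1, b1, W2, b2)
print(y)  # [3.5, 5.5]
```

The hidden layer has 8 units, but the output comes back to d_model = 2, which is what lets the FFN sit inside a residual connection.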

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        # expand to d_ff, apply GELU and dropout, then project back to d_model
        return self.w2(self.dropout(F.gelu(self.w1(x))))

Activation Functions

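ReLU (original paper): max(0, x). Simple and fast. Drawback: "dead" neurons. ReLU's gradient is 0 for negative inputs, so a neuron whose pre-activation is negative for every input always outputs 0, receives no gradient, and stops learning.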
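A small dependency-free sketch of the dead-neuron problem: the neuron below has a hypothetical large negative bias, so its pre-activation is negative for every sample input, its output is always 0, and the gradient reaching its weight is 0 as well.

```python
# ReLU and its derivative (using the common convention that the derivative is 0 at z == 0).
def relu(z):
    return max(0.0, z)

def relu_grad(z):
    return 1.0 if z > 0 else 0.0

# Hypothetical single neuron z = w * x + b with a large negative bias.
w, b = 0.5, -10.0
inputs = [-3.0, 0.0, 2.0, 5.0]
pre_acts = [w * x + b for x in inputs]            # [-11.5, -10.0, -9.0, -7.5]

outputs = [relu(z) for z in pre_acts]
# Chain rule: d relu(w*x + b) / dw = relu'(z) * x
grads_w = [relu_grad(z) * x for z, x in zip(pre_acts, inputs)]

print(outputs)                          # [0.0, 0.0, 0.0, 0.0]
print(any(g != 0 for g in grads_w))     # False: no gradient, so w never changes
```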

GELU (BERT, GPT models): Gaussian Error Linear Unit, GELU(x) = x · Φ(x), where Φ is the CDF of the standard normal distribution. It is smoother than ReLU, lets small negative values through, and has worked well empirically on language tasks.
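A dependency-free sketch comparing ReLU with exact GELU (Φ written via math.erf) and the common tanh approximation, which GPT-2's original code used and which PyTorch exposes as F.gelu(x, approximate="tanh"). The sample inputs are arbitrary.

```python
import math

def gelu(x):
    # Exact GELU: x * Phi(x), with the standard normal CDF Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # Tanh approximation of GELU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def relu(x):
    return max(0.0, x)

for x in [-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0]:
    print(f"x={x:+.1f}  relu={relu(x):.4f}  gelu={gelu(x):+.4f}  tanh-approx={gelu_tanh(x):+.4f}")
```

Unlike ReLU, GELU returns small negative values (with non-zero gradients) for moderately negative inputs, and the two curves converge as |x| grows.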

SwiGLU (LLaMA, PaLM): a gated variant that replaces the single expansion layer with two parallel projections. One branch goes through SiLU (Swish) and gates the other element-wise:

Python
class SwiGLUFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate branch (passed through SiLU)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # linear ("up") branch
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # projection back to d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: W2(silu(W1 x) ⊙ W3 x), where ⊙ is element-wise multiplication
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
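To see what the gate does, here is a dependency-free sketch of the hidden activation silu(x · W1) ⊙ (x · W3) for three hypothetical hidden units of one token. The pre-activations are given directly, and W2's projection back to d_model is omitted.

```python
import math

def silu(z):
    # SiLU (Swish with beta = 1): z * sigmoid(z), written to avoid overflow for very negative z
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)

def swiglu_hidden(gate_pre, up_pre):
    # Element-wise product of the SiLU-activated gate branch and the linear branch
    return [silu(g) * u for g, u in zip(gate_pre, up_pre)]

gate_pre = [4.0, 0.0, -4.0]  # hypothetical x · W1 values (gate branch)
up_pre = [2.0, 5.0, 5.0]     # hypothetical x · W3 values (linear branch)
h = swiglu_hidden(gate_pre, up_pre)
print([round(v, 4) for v in h])  # [7.8561, 0.0, -0.3597]
```

The second unit has a large linear value (5.0), but its gate is silu(0) = 0, so it contributes nothing: the gate, not just the magnitude, decides what passes through.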

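Because SwiGLU has three weight matrices instead of two, d_ff is usually shrunk to about 8/3 × d_model (instead of 4×) to keep the parameter count the same: 3 · d_model · (8/3 · d_model) = 8 · d_model² = 2 · d_model · (4 · d_model). LLaMA follows this convention.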
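The 8/3 sizing can be checked with a little arithmetic. This sketch assumes LLaMA-7B's d_model = 4096 and a round-up-to-a-multiple-of-256 step, which is our reading of Meta's reference LLaMA code; treat that rounding detail as an assumption. Under it, the result matches the 11008 FFN hidden size reported for LLaMA-7B.

```python
def standard_ffn_params(d_model: int) -> int:
    d_ff = 4 * d_model
    return 2 * d_model * d_ff            # W1 and W2, no biases: 8 * d_model**2

def swiglu_hidden_dim(d_model: int, multiple_of: int = 256) -> int:
    hidden = int(2 * (4 * d_model) / 3)  # 8/3 * d_model, truncated
    # Assumed rounding rule: round up to the nearest multiple of `multiple_of`
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

def swiglu_ffn_params(d_model: int, d_ff: int) -> int:
    return 3 * d_model * d_ff            # w1, w3: d_model x d_ff; w2: d_ff x d_model

d_model = 4096
d_ff = swiglu_hidden_dim(d_model)
print(d_ff)                                                              # 11008
print(swiglu_ffn_params(d_model, d_ff) / standard_ffn_params(d_model))  # 1.0078125
```

The rounding makes the SwiGLU FFN about 0.8% larger than the 4× baseline. Without it (multiple_of=1), sizes divisible by 3 match exactly.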


Position-Wise: Why It Matters

"Position-wise" means the same FFN is applied independently to each token position. The token at position 5 and the token at position 42 go through the same weights, and the FFN shares no information between them.

This is unlike attention, which explicitly computes relationships between all pairs of tokens. The FFN is purely a function of the single token's representation: it refines that representation using learned patterns.

Python
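# FFN processes each position independently. Using the FeedForward class above:
ffn = FeedForward(d_model=16, d_ff=64).eval()  # eval() disables dropout
x = torch.randn(2, 10, 16)                      # (batch, seq_len, d_model)

# Equivalent to looping over positions; no dependency on x[:, other_pos]:
loop_out = torch.stack([ffn(x[:, pos]) for pos in range(x.size(1))], dim=1)

# In practice, this is done in parallel with one batched matrix multiply:
output = ffn(x)  # (batch, seq_len, d_model): all positions at once
assert torch.allclose(loop_out, output, atol=1e-6)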
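The same independence can be checked without PyTorch. In this dependency-free sketch (hypothetical tiny weights, ReLU, no biases), editing the token at position 2 leaves the FFN output at position 0 unchanged. By contrast, a deliberately simplified stand-in for attention that just averages all tokens (not real attention) lets the edit reach position 0.

```python
def ffn_token(x, w1, w2):
    hidden = [max(0.0, sum(x[i] * w1[i][j] for i in range(len(x)))) for j in range(len(w1[0]))]
    return [sum(hidden[j] * w2[j][k] for j in range(len(hidden))) for k in range(len(w2[0]))]

def ffn_seq(seq, w1, w2):
    return [ffn_token(tok, w1, w2) for tok in seq]  # same weights, one token at a time

def uniform_mix(seq):
    # Simplified stand-in for attention: every position receives the mean of all tokens
    d = len(seq[0])
    mean = [sum(tok[k] for tok in seq) / len(seq) for k in range(d)]
    return [list(mean) for _ in seq]

W1 = [[1.0, -1.0, 0.5], [0.5, 1.0, -1.0]]  # d_model=2 -> d_ff=3
W2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # d_ff=3 -> d_model=2

seq = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
edited = [row[:] for row in seq]
edited[2] = [-5.0, 3.0]  # change only the token at position 2

print(ffn_seq(seq, W1, W2)[0] == ffn_seq(edited, W1, W2)[0])  # True: position 0 unaffected
print(uniform_mix(seq)[0] == uniform_mix(edited)[0])          # False: mixing spreads the change
```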

FFN as Associative Memory

Geva et al. (2021), in "Transformer Feed-Forward Layers Are Key-Value Memories," find that FFN layers behave like key-value memories:

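  • Each column of W₁ (a row of w1.weight in PyTorch's nn.Linear, which stores weights transposed) is a "key": a pattern that fires for certain inputs
  • The matching row of W₂ (a column of w2.weight) is a "value": what gets added to the residual stream when that key fires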

A specific neuron in the FFN expansion might fire strongly for tokens related to "capital cities", and its value vector in W₂ then adds a direction in embedding space that shifts outputs toward city-related predictions.
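A dependency-free sketch of this key-value reading, using the x · W₁ convention from the Architecture section, hypothetical tiny weights, ReLU, and no biases. It checks that the usual matrix form equals a sum over neurons of activation(x · key_i) times value_i, where key_i is column i of W₁ and value_i is row i of W₂.

```python
def relu(z):
    return max(0.0, z)

def ffn_matrix_form(x, W1, W2):
    hidden = [relu(sum(x[r] * W1[r][i] for r in range(len(x)))) for i in range(len(W1[0]))]
    return [sum(hidden[i] * W2[i][c] for i in range(len(hidden))) for c in range(len(W2[0]))]

def ffn_memory_form(x, W1, W2):
    out = [0.0] * len(W2[0])
    for i in range(len(W2)):                               # one iteration per FFN neuron
        key = [W1[r][i] for r in range(len(x))]            # column i of W1
        value = W2[i]                                      # row i of W2
        coeff = relu(sum(a * b for a, b in zip(x, key)))   # how strongly key i fires
        out = [o + coeff * v for o, v in zip(out, value)]  # add the weighted value vector
    return out

W1 = [[1.0, 0.0, -1.0],
      [0.0, 1.0, 1.0]]  # d_model=2 rows, d_ff=3 columns (the keys)
W2 = [[0.5, 0.0],
      [0.0, 2.0],
      [1.0, 1.0]]       # d_ff=3 rows (the values), d_model=2 columns

x = [3.0, 1.0]
print(ffn_matrix_form(x, W1, W2))  # [1.5, 2.0]
print(ffn_memory_form(x, W1, W2))  # [1.5, 2.0]
```

For this input, key 2 does not fire (x · key = -2, so ReLU gives 0), so its value vector [1.0, 1.0] never reaches the output.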

This view is consistent with several observations:

  • Factual-editing methods such as ROME target FFN weights
  • Fine-tuning on new facts is thought to act largely through FFN weights
  • Larger d_ff (more neurons) is generally associated with more factual recall capacity

Parameter Count

Python
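d_model = 768
d_ff = 3072  # 4× expansion

# Weight-matrix parameters per FFN layer (biases add only d_ff + d_model = 3,840)
w1_params = d_model * d_ff          # 768 × 3072 = 2,359,296
w2_params = d_ff * d_model          # 3072 × 768 = 2,359,296
ffn_params = w1_params + w2_params  # 4,718,592 ≈ 4.7M per layer

# GPT-2 small (12 layers)
total_ffn = ffn_params * 12  # 56,623,104 ≈ 56.6M of ~124M total params
                             # (the GPT-2 paper reports 117M)

# Attention's Q, K, V and output projections hold 4 × d_model² per layer,
# so the FFN holds ~2/3 of each block's weight-matrix parameters (8 vs 4 × d_model²).
# Token and position embeddings account for nearly all of the rest of the total.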
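To see where the rest of GPT-2's parameters live, here is a rough per-component count. It assumes GPT-2 small's published configuration (12 layers, d_model = 768, d_ff = 3072, a 50,257-token vocabulary, 1,024 positions), biases on every linear layer, two LayerNorms per block plus a final one, and an output layer tied to the token embedding. Under those assumptions it reproduces the ~124M total.

```python
def gpt2_param_breakdown(n_layer=12, d_model=768, d_ff=3072, vocab=50257, n_ctx=1024):
    """Rough parameter count for a GPT-2-style decoder under the stated assumptions."""
    embeddings = vocab * d_model + n_ctx * d_model              # token + position embeddings
    attn = d_model * 3 * d_model + 3 * d_model                  # fused Q, K, V projection + bias
    attn += d_model * d_model + d_model                         # attention output projection + bias
    ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)  # W1, b1, W2, b2
    layer_norms = 2 * (2 * d_model)                             # two LayerNorms per block (scale, shift)
    per_layer = attn + ffn + layer_norms
    total = embeddings + n_layer * per_layer + 2 * d_model      # + final LayerNorm
    return {"embeddings": embeddings, "attention": n_layer * attn,
            "ffn": n_layer * ffn, "total": total}

parts = gpt2_param_breakdown()
print(parts["total"])                                                   # 124439808
print(f"FFN share of all params: {parts['ffn'] / parts['total']:.1%}")  # 45.5%
print(f"FFN share of attention + FFN: {parts['ffn'] / (parts['ffn'] + parts['attention']):.1%}")  # 66.7%
```

Under these assumptions the FFN makes up about two-thirds of the transformer-block weights but under half of all parameters, because the 50,257 × 768 token embedding alone is about 38.6M.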

Mixture of Experts FFN

Mixture of Experts (MoE) scales the FFN by replacing one FFN with N expert FFNs and adding a learned router that selects a small subset of experts (often 2) for each token:

Python
class MoEFFN(nn.Module):
    def __init__(self, d_model, d_ff, num_experts=8, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(num_experts)])
        self.router = nn.Linear(d_model, num_experts)
        self.top_k = top_k

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        scores = self.router(x)                              # (batch, seq_len, num_experts)
        weights, indices = torch.topk(scores, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)

        output = torch.zeros_like(x)
        for k in range(self.top_k):
            expert_idx = indices[..., k]     # which expert
            expert_w = weights[..., k:k+1]  # its weight
            for i, expert in enumerate(self.experts):
                mask = (expert_idx == i)
                if mask.any():
                    output[mask] += expert_w[mask] * expert(x[mask])
        return output
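The routing arithmetic inside forward can be traced by hand for a single token. This dependency-free sketch mirrors the same steps: top-k selection, a softmax over only the selected scores, and a weighted sum of the chosen experts' outputs. The router scores and the four toy "experts" are hypothetical.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def route_top_k(scores, k):
    # Indices of the k highest router scores, then softmax over just those scores
    idx = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return idx, softmax([scores[i] for i in idx])

def moe_token(x, scores, experts, k=2):
    idx, weights = route_top_k(scores, k)
    out = [0.0] * len(x)
    for i, w in zip(idx, weights):
        y = experts[i](x)  # only the selected experts run
        out = [o + w * yi for o, yi in zip(out, y)]
    return out, idx, weights

# Four toy "experts": each scales the token by a different factor.
experts = [lambda x, s=s: [s * xi for xi in x] for s in (1.0, 2.0, 3.0, 4.0)]
scores = [0.1, 2.0, -1.0, 2.0 + math.log(3.0)]  # expert 3 scores highest, then expert 1

out, idx, weights = moe_token([1.0, -1.0], scores, experts, k=2)
print(idx)      # [3, 1]
print(weights)  # approximately [0.75, 0.25]
print(out)      # approximately [3.5, -3.5]
```

Experts 0 and 2 are never called for this token, which is where MoE's compute savings come from.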

MoE increases model capacity without proportionally increasing compute: each token passes through only 2 of the 8 experts. Mixtral-8x7B uses this architecture (about 47B total parameters, roughly 13B active per token), and GPT-4 reportedly does as well.
