Mixture of Experts: Sparse Scaling
How Mixture of Experts (MoE) scales model capacity without proportionally scaling compute. Covers router mechanisms, load balancing, expert collapse, and models like Mixtral.
The Core Idea
Standard transformers are "dense": every parameter is used for every token. A 70B-parameter model uses all 70B parameters to process each token.
Mixture of Experts (MoE) replaces the dense FFN layer with N "expert" FFN layers and a router that selects K of them for each token (commonly K=1 or K=2). With N=8 experts and K=2, only 2 experts run per token, so the number of active parameters per token is much smaller than the total parameter count.
Dense 70B model: uses 70B params per token
MoE 8×7B model: uses ~13B params per token (2 of 8 experts active)
but total parameters = ~47B (8 expert FFNs + shared attention and embeddings; not 8 × 7B = 56B, because only the FFNs are replicated)
Result: MoE scales capacity (total parameters) without proportionally scaling compute (active parameters per forward pass).
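A quick way to see this decoupling is to count FFN weights for a single layer as the number of experts grows. The sketch below is a minimal illustration, assuming SwiGLU experts (three d_model × d_ff matrices, no biases) with Mixtral-like sizes d_model = 4096 and d_ff = 14336; attention, embeddings, and the router are ignored, and the helper name `ffn_params_per_layer` is hypothetical.

```python
from typing import Tuple


def ffn_params_per_layer(d_model: int, d_ff: int, num_experts: int, top_k: int) -> Tuple[int, int]:
    """Return (total, active) FFN parameters for one MoE layer.

    Assumes each expert is a SwiGLU FFN with three d_model x d_ff matrices
    (gate, up, down) and no biases; router weights are ignored.
    """
    per_expert = 3 * d_model * d_ff
    total = num_experts * per_expert  # every expert must be stored
    active = top_k * per_expert       # only top_k experts run for a given token
    return total, active


for n in (8, 16, 64):
    total, active = ffn_params_per_layer(4096, 14336, num_experts=n, top_k=2)
    print(f"N={n:>2}: total={total / 1e9:.2f}B  active={active / 1e9:.2f}B")
```

Total FFN parameters grow linearly with N, while the active count stays fixed at K experts' worth.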
Router Mechanism
The router is a learned linear layer that takes the token representation and produces weights over experts:
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKRouter(nn.Module):
    def __init__(self, d_model: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        x: (batch * seq_len, d_model); each token is routed independently
        Returns: expert weights and indices for the top-k experts
        """
        # Compute routing scores
        logits = self.gate(x)  # (batch*seq, num_experts)
        # Select top-k experts
        top_k_weights, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        # Normalize weights across the selected experts with softmax
        top_k_weights = F.softmax(top_k_weights, dim=-1)
        return top_k_weights, top_k_indices  # Both: (batch*seq, top_k)
Full MoE FFN Layer
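The full layer below applies the router to every token in a batch. Stripped of tensors, one token's path is short: score the experts, keep the top k, softmax over the kept scores, and sum the chosen experts' outputs with those weights. A dependency-free sketch, where the four scalar "experts" and the router scores are purely hypothetical stand-ins for FFNs and a learned gate:

```python
import math
from typing import Callable, List, Tuple


def top_k_route(logits: List[float], k: int) -> Tuple[List[int], List[float]]:
    """Pick the k largest logits, then softmax over just those k (as TopKRouter does)."""
    indices = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in indices)  # subtract the max for numerical stability
    exps = [math.exp(logits[i] - m) for i in indices]
    z = sum(exps)
    return indices, [e / z for e in exps]


def moe_token_forward(x: float, logits: List[float],
                      experts: List[Callable[[float], float]], k: int = 2) -> float:
    """Weighted sum of the selected experts' outputs for a single (scalar) token."""
    indices, weights = top_k_route(logits, k)
    return sum(w * experts[i](x) for i, w in zip(indices, weights))


# Toy "experts" acting on a scalar; real experts are FFNs on d_model vectors.
experts = [lambda v: v + 1.0, lambda v: 2.0 * v, lambda v: -v, lambda v: v * v]
logits = [1.0, 3.0, 2.0, 0.5]  # hypothetical router scores for one token

idx, w = top_k_route(logits, k=2)
print(idx, [round(x, 4) for x in w])                    # [1, 2] [0.7311, 0.2689]
print(round(moe_token_forward(3.0, logits, experts), 4))  # 3.5795
```

Taking the softmax over only the k kept logits gives the same weights as a full softmax renormalized over the selected experts, because the shared denominator cancels.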
class MoELayer(nn.Module):
    def __init__(self, d_model: int, d_ff: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Each expert is an independent FFN (a simplified two-matrix FFN here;
        # Mixtral's experts use SwiGLU with three matrices)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff, bias=False),
                nn.SiLU(),
                nn.Linear(d_ff, d_model, bias=False),
            )
            for _ in range(num_experts)
        ])
        self.router = TopKRouter(d_model, num_experts, top_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, d_model = x.shape
        x_flat = x.reshape(-1, d_model)  # (batch*seq, d_model)
        # Get routing decisions
        weights, indices = self.router(x_flat)
        # weights, indices: (n_tokens, top_k)
        output = torch.zeros_like(x_flat)
        # Process each expert
        for expert_idx, expert in enumerate(self.experts):
            # Row indices of tokens routed to this expert, and which top-k slot chose it
            token_idx, slot_idx = torch.where(indices == expert_idx)
            if token_idx.numel() == 0:
                continue  # Skip expert if no tokens routed to it
            # Run the expert only on its tokens
            expert_output = expert(x_flat[token_idx])
            # Weight the output by the routing weight
            expert_weights = weights[token_idx, slot_idx].unsqueeze(-1)
            # topk indices are distinct per token, so token_idx has no duplicates
            output[token_idx] += expert_weights * expert_output
        return output.view(batch, seq_len, d_model)
Load Balancing: Preventing Expert Collapse
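Expert collapse is easy to quantify. Below is a minimal pure-Python sketch of the Switch-style penalty N · Σ f_i · P_i (the same quantity `compute_aux_loss` computes later in this section), evaluated on two hypothetical routings with N = 4 experts and top-1 routing: one perfectly balanced, one collapsed onto expert 0. The helper `switch_aux_loss` and all the probabilities are illustrative assumptions.

```python
from typing import List


def switch_aux_loss(assignments: List[List[int]], probs: List[List[float]], num_experts: int) -> float:
    """N * sum_i f_i * P_i, with f_i counted once per top-k slot."""
    n_tokens = len(assignments)
    f = [0.0] * num_experts
    for chosen in assignments:  # chosen = this token's top-k expert ids
        for e in chosen:
            f[e] += 1.0 / n_tokens
    # P_i = mean router probability assigned to expert i
    P = [sum(p[i] for p in probs) / n_tokens for i in range(num_experts)]
    return num_experts * sum(fi * pi for fi, pi in zip(f, P))


N = 4
# Balanced: each expert gets exactly one of 4 tokens, router probabilities uniform.
balanced = switch_aux_loss([[0], [1], [2], [3]], [[0.25] * N] * 4, N)
# Collapsed: every token goes to expert 0 and the router is confident about it.
collapsed = switch_aux_loss([[0]] * 4, [[0.97, 0.01, 0.01, 0.01]] * 4, N)
print(balanced, round(collapsed, 2))  # 1.0 3.88
```

Balanced top-1 routing scores 1.0 (top_k in general), while collapse pushes the loss toward N. In the PyTorch version only P_i carries a gradient, since f_i comes from a hard top-k selection.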
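Without explicit load balancing, the router tends to send most tokens to a small number of experts (expert collapse). The effect is self-reinforcing: experts that are picked more often receive more gradient updates, improve faster, and get picked even more. This wastes the capacity of the unused experts. The Switch Transformer auxiliary loss counteracts it by penalizing experts that receive both many tokens and high router probability: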
class LoadBalancedMoELayer(MoELayer):
    def __init__(self, d_model, d_ff, num_experts, top_k=2, aux_loss_weight=0.01):
        super().__init__(d_model, d_ff, num_experts, top_k)
        self.aux_loss_weight = aux_loss_weight

    def compute_aux_loss(
        self,
        router_logits: torch.Tensor,  # (n_tokens, num_experts)
        top_k_indices: torch.Tensor,  # (n_tokens, top_k)
    ) -> torch.Tensor:
        """
        Switch Transformer auxiliary loss: encourages uniform expert utilization.
        Loss = num_experts × sum_i(f_i × P_i)
        where f_i = fraction of tokens routed to expert i (each top-k slot
                    counts, so sum_i f_i = top_k)
              P_i = mean routing probability for expert i
        """
        # Fraction of tokens routed to each expert (not differentiable)
        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()  # (n_tokens, top_k, num_experts)
        expert_mask = expert_mask.sum(dim=1)  # (n_tokens, num_experts)
        f = expert_mask.mean(dim=0)  # (num_experts,)
        # Mean routing probability for each expert (gradients flow through P)
        P = F.softmax(router_logits, dim=-1).mean(dim=0)  # (num_experts,)
        # Auxiliary loss: minimize product (high load × high prob = collapse)
        aux_loss = self.num_experts * (f * P).sum()
        return aux_loss

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch, seq_len, d_model = x.shape
        x_flat = x.reshape(-1, d_model)
        logits = self.router.gate(x_flat)  # raw logits, needed for P_i
        weights, indices = self.router(x_flat)
        # Auxiliary loss: add it to the main training loss
        aux_loss = self.compute_aux_loss(logits, indices) * self.aux_loss_weight
        output = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(indices == expert_idx)
            if token_idx.numel() == 0:
                continue
            expert_output = expert(x_flat[token_idx])
            expert_weights = weights[token_idx, slot_idx].unsqueeze(-1)
            output[token_idx] += expert_weights * expert_output
        return output.view(batch, seq_len, d_model), aux_loss
Expert Parallelism for Training and Inference
In practice, experts are sharded across GPUs, often several per device:
GPU 0: Expert 0, Expert 1
GPU 1: Expert 2, Expert 3
GPU 2: Expert 4, Expert 5
GPU 3: Expert 6, Expert 7
Each token is sent to the GPUs that hold its selected experts (an all-to-all exchange), processed there, and returned to its original GPU by a second all-to-all. This is "expert parallelism", and it is how frameworks such as Megatron-LM and DeepSpeed implement MoE at scale.
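The dispatch half of that exchange is essentially a bucketing step: for each (token, expert) pair chosen by the router, look up which GPU owns the expert. A minimal pure-Python sketch, assuming the contiguous 2-experts-per-GPU placement shown above; `plan_dispatch` is a hypothetical helper and the routing decisions are made up. Real frameworks perform the actual transfer with collective communication primitives rather than Python lists.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def plan_dispatch(topk_experts: List[List[int]], experts_per_gpu: int) -> Dict[int, List[Tuple[int, int]]]:
    """Group (token_id, expert_id) pairs by the GPU that owns the expert.

    Assumes contiguous placement: expert e lives on GPU e // experts_per_gpu.
    """
    sends: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    for token_id, experts in enumerate(topk_experts):
        for expert_id in experts:
            sends[expert_id // experts_per_gpu].append((token_id, expert_id))
    return dict(sends)


# Hypothetical top-2 routing decisions for 4 tokens held on one GPU.
routing = [[0, 5], [2, 3], [7, 1], [4, 6]]
for gpu, items in sorted(plan_dispatch(routing, experts_per_gpu=2).items()):
    print(f"GPU {gpu}: {items}")
```

After the experts run, the combine step reverses these (token, expert) pairs and weights each returned output by its router weight. Token 1 sends both copies to GPU 1, which shows how routing decisions translate directly into per-GPU communication and compute load.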
Mixtral-8×7B Architecture
Mixtral uses the Mistral 7B architecture but replaces each FFN with an 8-expert MoE layer:
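# Mixtral-8×7B configuration
mixtral_config = {
    "d_model": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,      # GQA, same as Mistral 7B
    "num_experts": 8,
    "top_k": 2,           # 2 of 8 experts active per token
    "ffn_hidden": 14336,  # Per expert
    "vocab_size": 32000,
}

d = mixtral_config["d_model"]
head_dim = d // mixtral_config["n_heads"]         # 128
kv_dim = mixtral_config["n_kv_heads"] * head_dim  # 1024: GQA shrinks WK and WV
n_layers = mixtral_config["n_layers"]
n_experts = mixtral_config["num_experts"]
top_k = mixtral_config["top_k"]
vocab = mixtral_config["vocab_size"]

# Parameter breakdown:
# Shared (non-expert) parameters per layer: attention, norms, router
shared_per_layer = (
    2 * d * d +       # WQ, WO
    2 * d * kv_dim +  # WK, WV (8 KV heads under GQA)
    2 * d +           # Two RMSNorms
    d * n_experts     # Router gate
)
# Expert FFN parameters per layer
per_expert = 3 * d * mixtral_config["ffn_hidden"]  # SwiGLU: gate, up, down
expert_params = n_experts * per_expert

total_params = (
    2 * vocab * d +   # Input embedding + untied output head
    n_layers * (shared_per_layer + expert_params) +
    d                 # Final norm
)
# Subtract the (num_experts - top_k) = 6 inactive experts in every layer
active_params = total_params - n_layers * (n_experts - top_k) * per_expert

print(f"Total parameters: {total_params/1e9:.1f}B")  # ~46.7B
print(f"Active per token: {active_params/1e9:.1f}B")  # ~12.9B
MoE vs Dense: Practical Tradeoffs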
| | Dense 7B | MoE 8×7B (Mixtral) |
|---|---|---|
| Total parameters | 7B | 46.7B |
| Active per token | 7B | 12.9B |
| Compute per token | Baseline | ~2× baseline |
| Memory footprint | ~14GB (bf16) | ~94GB (bf16) |
| Quality | Good | Better (larger effective capacity) |
| Inference speed (single GPU) | Faster (fits in memory) | Slower (may not fit) |
| Multi-GPU inference | Optional | Often required |
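The memory and compute rows follow from simple arithmetic: bf16 stores 2 bytes per parameter, and per-token compute scales roughly with active parameters. A minimal sketch using the rounded counts from the table; it covers weights only, so real memory use (KV cache, activations, framework overhead) is higher, and `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory for the weights alone, in GB (1e9 bytes); bf16 = 2 bytes/param.

    Ignores KV cache, activations, and framework overhead.
    """
    return num_params * bytes_per_param / 1e9


dense_total = dense_active = 7.0e9       # rounded, as in the table
moe_total, moe_active = 46.7e9, 12.9e9   # Mixtral-8x7B, as in the table

print(f"Dense 7B weights: {weight_memory_gb(dense_total):.1f} GB")  # 14.0 GB
print(f"Mixtral weights:  {weight_memory_gb(moe_total):.1f} GB")    # 93.4 GB
# Per-token compute scales roughly with active parameters (~2 FLOPs per weight)
print(f"Compute ratio: {moe_active / dense_active:.2f}x")           # 1.84x
```

The ~1.8× compute ratio is what the table rounds to "~2× baseline", while memory grows with total parameters, about 6.7× the dense model.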
MoE excels when:
- You have multiple GPUs for inference (experts distributed across devices)
- Training compute budget is fixed but you want higher model quality
- Tasks benefit from specialization (different experts learn different domains)
MoE is harder when:
- Single-GPU deployment is required (all experts must be resident in memory even though only K run per token)
- Low-latency inference is critical (routing overhead + expert dispatch adds latency)
- Training stability is a priority (load balancing instability can cause training divergence)
Expert Specialization
Analyses of trained MoE models report that experts do specialize, though not always along the lines one might expect. In language models, different experts have been observed to activate for:
- Syntactic structures (verbs, nouns, punctuation patterns)
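- Domain-specific content (medical text, code, legal language), although the evidence is mixed: the Mixtral paper's routing analysis found no obvious topic-based expert assignment, with routing tracking syntax more than subject matter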
- Position-dependent patterns (beginning of sentences, end of paragraphs)
This emergent specialization is one common explanation for why MoE models often outperform dense models with the same active compute: each expert can devote its parameters to a narrower part of the data distribution.