Mixture of Experts: Sparse Scaling

How Mixture of Experts (MoE) scales model capacity without proportionally scaling compute. Covers router mechanisms, load balancing, expert collapse, and models like Mixtral.

Asma Hafeez Khan · May 16, 2026 · 6 min read
Tags: Transformers, Mixture of Experts, MoE, Scaling

The Core Idea

Standard transformers are "dense" — every parameter is used for every token. A 70B parameter model uses all 70B parameters to process each token.

Mixture of Experts (MoE) replaces the dense FFN layer with N "expert" FFN layers and a router that selects K of them for each token (typically K=2). With N=8 experts, only 2 are active per token, so the number of active parameters per token is much smaller than the total parameter count.

Dense 70B model:  uses 70B params per token
MoE 8×7B model:   uses ~13B params per token (2 of 8 experts active)
                  but total parameters = ~47B (8 expert FFNs + shared layers)

Result: MoE scales capacity (total parameters) without proportionally scaling compute (active parameters per forward pass).


Router Mechanism

The router is a learned linear layer that takes the token representation and produces weights over experts:

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKRouter(nn.Module):
    def __init__(self, d_model: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        x: (batch * seq_len, d_model) — routed per token
        Returns: expert weights and indices for top-k experts
        """
        # Compute routing scores
        logits = self.gate(x)  # (batch*seq, num_experts)

        # Select top-k experts
        top_k_weights, top_k_indices = torch.topk(logits, self.top_k, dim=-1)

        # Normalize weights across selected experts with softmax
        top_k_weights = F.softmax(top_k_weights, dim=-1)

        return top_k_weights, top_k_indices  # Both: (batch*seq, top_k)
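
The router above applies softmax after selecting the top-k logits. The sketch below, using made-up logits for a single token and only the standard library, checks that this gives the same weights as taking a softmax over all experts and then renormalizing the selected probabilities. The helper `softmax` and the logit values are illustrative assumptions, not part of the router code.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical router logits for one token over 4 experts
logits = [2.0, 0.5, 1.0, -1.0]
top_k = 2

# Indices of the top-k logits, highest first
top_idx = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]

# Option 1 (as in TopKRouter): softmax over the selected logits only
weights_a = softmax([logits[i] for i in top_idx])

# Option 2: softmax over all experts, then renormalize the selected entries
full = softmax(logits)
selected = [full[i] for i in top_idx]
weights_b = [p / sum(selected) for p in selected]

print(top_idx)                            # [0, 2]
print([round(w, 4) for w in weights_a])   # [0.7311, 0.2689]
print(all(abs(a - b) < 1e-12 for a, b in zip(weights_a, weights_b)))  # True
```

Either ordering yields the same per-token mixing weights. The two differ only in whether the full probability vector is also available, which the load-balancing loss needs.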

Full MoE FFN Layer

Python
class MoELayer(nn.Module):
    def __init__(self, d_model: int, d_ff: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Each expert is an independent FFN
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff, bias=False),
                nn.SiLU(),
                nn.Linear(d_ff, d_model, bias=False),
            )
            for _ in range(num_experts)
        ])
        self.router = TopKRouter(d_model, num_experts, top_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, d_model = x.shape
        x_flat = x.view(-1, d_model)  # (batch*seq, d_model)
        n_tokens = x_flat.shape[0]

        # Get routing decisions
        weights, indices = self.router(x_flat)
        # weights, indices: (n_tokens, top_k)

        output = torch.zeros_like(x_flat)

        # Process each expert
        for expert_idx, expert in enumerate(self.experts):
            # Find which tokens and which routing slot use this expert
            token_mask, slot_mask = torch.where(indices == expert_idx)

            if len(token_mask) == 0:
                continue  # Skip expert if no tokens routed to it

            # Get the tokens for this expert
            expert_input = x_flat[token_mask]

            # Run expert
            expert_output = expert(expert_input)

            # Weight the output by the routing weight
            expert_weights = weights[token_mask, slot_mask].unsqueeze(-1)
            output[token_mask] += expert_weights * expert_output

        return output.view(batch, seq_len, d_model)
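
The dispatch loop depends on `torch.where(indices == expert_idx)` returning two index tensors: the rows (tokens) and the columns (top-k slots) where that expert was chosen. Despite their names in the layer above, `token_mask` and `slot_mask` are index tensors rather than boolean masks. A small sketch with made-up routing values shows what they contain:

```python
import torch

# Hypothetical routing result for 3 tokens with top_k = 2
indices = torch.tensor([[0, 2],
                        [1, 0],
                        [2, 3]])
weights = torch.tensor([[0.7, 0.3],
                        [0.6, 0.4],
                        [0.9, 0.1]])

# Rows (tokens) and columns (routing slots) where expert 0 was selected
token_idx, slot_idx = torch.where(indices == 0)
print(token_idx.tolist())  # [0, 1]
print(slot_idx.tolist())   # [0, 1]

# Routing weight that each of those tokens assigned to expert 0
expert_weights = weights[token_idx, slot_idx]
print(expert_weights)
```

Token 0 picked expert 0 in its first slot and token 1 picked it in its second slot, so the gathered weights are those tokens' weights for expert 0 specifically.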

Load Balancing: Preventing Expert Collapse

Without explicit load balancing, the router tends to send most tokens to a small number of experts (expert collapse), which wastes the capacity of the rarely used experts. An auxiliary loss added to the training objective counteracts this:

Python
class LoadBalancedMoELayer(MoELayer):
    def __init__(self, d_model, d_ff, num_experts, top_k=2, aux_loss_weight=0.01):
        super().__init__(d_model, d_ff, num_experts, top_k)
        self.aux_loss_weight = aux_loss_weight

    def compute_aux_loss(
        self,
        router_logits: torch.Tensor,  # (n_tokens, num_experts)
        top_k_indices: torch.Tensor,  # (n_tokens, top_k)
    ) -> torch.Tensor:
        """
        Switch Transformer-style auxiliary loss: encourages uniform expert utilization.
        Loss = num_experts × sum_i(f_i × P_i)
        where f_i = fraction of tokens that selected expert i among their top-k
                    (sums to top_k across experts)
              P_i = mean routing probability for expert i
        """
        # Fraction of tokens routed to each expert
        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
        expert_mask = expert_mask.sum(dim=1)  # (n_tokens, num_experts)
        f = expert_mask.mean(dim=0)  # (num_experts,)

        # Mean routing probability for each expert
        P = F.softmax(router_logits, dim=-1).mean(dim=0)  # (num_experts,)

        # Auxiliary loss: minimize product (high load × high prob = collapse)
        aux_loss = self.num_experts * (f * P).sum()
        return aux_loss

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch, seq_len, d_model = x.shape
        x_flat = x.view(-1, d_model)

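        logits = self.router.gate(x_flat)  # (n_tokens, num_experts)
        # Reuse the logits for top-k selection instead of running the gate twice
        weights, indices = torch.topk(logits, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)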

        # Compute auxiliary loss for training
        aux_loss = self.compute_aux_loss(logits, indices) * self.aux_loss_weight

        output = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            token_mask, slot_mask = torch.where(indices == expert_idx)
            if len(token_mask) == 0:
                continue
            expert_output = expert(x_flat[token_mask])
            expert_weights = weights[token_mask, slot_mask].unsqueeze(-1)
            output[token_mask] += expert_weights * expert_output

        return output.view(batch, seq_len, d_model), aux_loss
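
To get a feel for the scale of the auxiliary loss, the sketch below evaluates the same formula on two made-up routing patterns: one perfectly balanced and one collapsed onto a single expert. The standalone function `aux_loss` and the toy tensors are assumptions for illustration, and top_k is set to 1 to keep the arithmetic easy to follow.

```python
import torch
import torch.nn.functional as F

def aux_loss(router_logits, top_k_indices, num_experts):
    """Same formula as compute_aux_loss: num_experts * sum_i(f_i * P_i)."""
    mask = F.one_hot(top_k_indices, num_experts).float().sum(dim=1)
    f = mask.mean(dim=0)                              # load per expert
    P = F.softmax(router_logits, dim=-1).mean(dim=0)  # mean probability per expert
    return num_experts * (f * P).sum()

num_experts, n_tokens = 4, 8

# Balanced: uniform probabilities, tokens spread evenly over the experts
balanced_logits = torch.zeros(n_tokens, num_experts)
balanced_idx = (torch.arange(n_tokens) % num_experts).unsqueeze(-1)

# Collapsed: every token strongly prefers expert 0
collapsed_logits = torch.zeros(n_tokens, num_experts)
collapsed_logits[:, 0] = 10.0
collapsed_idx = torch.zeros(n_tokens, 1, dtype=torch.long)

balanced = aux_loss(balanced_logits, balanced_idx, num_experts).item()
collapsed = aux_loss(collapsed_logits, collapsed_idx, num_experts).item()
print(f"balanced:  {balanced:.3f}")   # 1.000
print(f"collapsed: {collapsed:.3f}")  # close to 4, i.e. num_experts
```

With top-1 routing the loss ranges from 1 (uniform) up to roughly `num_experts` (full collapse), so the gradient pushes the router away from concentrating load and probability on the same expert.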

Expert Parallelism for Training and Inference

In practice, experts are sharded across GPUs, for example two experts per device:

GPU 0: Expert 0, Expert 1
GPU 1: Expert 2, Expert 3
GPU 2: Expert 4, Expert 5
GPU 3: Expert 6, Expert 7

Tokens are dispatched to GPUs based on the router's decisions (all-to-all communication), processed, then gathered back. This is "expert parallelism" and is how frameworks like Megatron-LM and DeepSpeed implement MoE at scale.
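
The bookkeeping behind that dispatch step can be sketched in plain Python. This is only a single-process illustration of which (token, expert) pairs would be sent to which device, assuming the contiguous two-experts-per-GPU layout shown above. Real frameworks perform this with collective all-to-all operations, and the names `gpu_for_expert` and `build_dispatch` are hypothetical.

```python
from collections import defaultdict

NUM_EXPERTS = 8
NUM_GPUS = 4
EXPERTS_PER_GPU = NUM_EXPERTS // NUM_GPUS  # 2, matching the layout above

def gpu_for_expert(expert_idx: int) -> int:
    """Contiguous placement: experts 0-1 on GPU 0, 2-3 on GPU 1, and so on."""
    return expert_idx // EXPERTS_PER_GPU

def build_dispatch(top_k_indices):
    """Group (token, expert) pairs by destination GPU.

    top_k_indices: one list of selected expert ids per token.
    Returns {gpu: [(token_id, expert_id), ...]}.
    """
    plan = defaultdict(list)
    for token_id, experts in enumerate(top_k_indices):
        for expert_id in experts:
            plan[gpu_for_expert(expert_id)].append((token_id, expert_id))
    return dict(plan)

# Hypothetical router output for 4 tokens, top_k = 2
routing = [[0, 5], [1, 2], [6, 7], [0, 3]]
plan = build_dispatch(routing)
for gpu in sorted(plan):
    print(gpu, plan[gpu])
# 0 [(0, 0), (1, 1), (3, 0)]
# 1 [(1, 2), (3, 3)]
# 2 [(0, 5)]
# 3 [(2, 6), (2, 7)]
```

Even in this tiny example GPU 0 receives three token copies while GPU 2 receives one, which is why load balancing matters for throughput as well as for model quality.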


Mixtral-8×7B Architecture

Mixtral uses the Mistral architecture with each dense FFN replaced by an MoE FFN:

Python
# Mixtral-8×7B configuration
mixtral_config = {
    "d_model": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,       # GQA same as Mistral
    "num_experts": 8,
    "top_k": 2,            # 2 of 8 experts active per token
    "ffn_hidden": 14336,   # Per expert
    "vocab_size": 32000,
}

# Parameter breakdown (per transformer layer)
head_dim = 4096 // 32                      # 128
attn_params = (
    4096 * 4096 * 2 +                      # WQ, WO
    4096 * (8 * head_dim) * 2              # WK, WV (GQA: only 8 KV heads)
)
shared_per_layer = (
    attn_params +
    4096 * 8 +                             # Router gate (d_model × num_experts)
    4096 * 2                               # Two RMSNorms
)
# FFN parameters for one expert: 3 matrices (SwiGLU)
params_per_expert = 4096 * 14336 * 3
expert_params = 8 * params_per_expert      # All 8 experts in a layer

total_params = (
    32000 * 4096 * 2 +                     # Input embedding + untied output head
    32 * (shared_per_layer + expert_params) +
    4096                                   # Final norm
)

# Only top_k = 2 experts run per token: subtract the 6 inactive experts per layer
active_params = total_params - 32 * (8 - 2) * params_per_expert

print(f"Total parameters: {total_params/1e9:.1f}B")   # 46.7B
print(f"Active per token: {active_params/1e9:.1f}B")  # 12.9B

MoE vs Dense: Practical Tradeoffs

| | Dense 7B | MoE 8×7B (Mixtral) |
|---|---|---|
| Total parameters | 7B | 46.7B |
| Active per token | 7B | 12.9B |
| Compute per token | Baseline | ~2× baseline |
| Memory footprint | ~14GB (bf16) | ~94GB (bf16) |
| Quality | Good | Better (larger effective capacity) |
| Inference speed (single GPU) | Faster (fits in memory) | Slower (may not fit) |
| Multi-GPU inference | Optional | Often required |
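
The memory and compute rows follow from simple arithmetic. A rough sketch, counting weights only (no KV cache or activations) and assuming 2 bytes per parameter for bf16, with the hypothetical helper `weight_memory_gb`:

```python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory for weights only, in GB (1 GB = 1e9 bytes); bf16 uses 2 bytes."""
    return num_params * bytes_per_param / 1e9

dense_gb = weight_memory_gb(7e9)
moe_gb = weight_memory_gb(46.7e9)
active_ratio = 12.9 / 7  # active parameters per token, MoE vs dense

print(f"Dense 7B: {dense_gb:.1f} GB")                   # 14.0 GB
print(f"MoE 8x7B: {moe_gb:.1f} GB")                     # 93.4 GB
print(f"Active-parameter ratio: {active_ratio:.2f}x")   # 1.84x
```

All experts must be resident in memory even though only two run per token, so memory scales with total parameters while per-token compute scales with active parameters.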

MoE excels when:

  • You have multiple GPUs for inference (experts distributed across devices)
  • Training compute budget is fixed but you want higher model quality
  • Tasks benefit from specialization (different experts learn different domains)

MoE is harder when:

  • Single-GPU deployment is required (memory doesn't fit)
  • Low-latency inference is critical (routing overhead + expert dispatch adds latency)
  • Training stability is a priority (load balancing instability can cause training divergence)

Expert Specialization

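Research suggests that experts in trained MoE models specialize to some degree, although the reported patterns vary between models and are often more syntactic than topical. In language models, different experts have been reported to activate for: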

  • Syntactic structures (verbs, nouns, punctuation patterns)
  • Domain-specific content (medical text, code, legal language)
  • Position-dependent patterns (beginning of sentences, end of paragraphs)

This emergent specialization is one proposed reason MoE models often outperform dense models trained with the same compute: each expert can develop a more focused representation for its part of the data distribution.
