Implement BPE Tokenization
Byte Pair Encoding step by step: initialize character vocabulary, merge most frequent pairs iteratively, apply merges to new text, and complete Python implementation.
What is BPE and Why Does It Matter
Byte Pair Encoding (BPE) is the tokenization algorithm behind GPT-2, GPT-3, GPT-4, LLaMA, and many other LLMs. Understanding BPE is more than interview prep: it is fundamental knowledge for anyone debugging why an LLM fails on rare words, code snippets, or non-English text.
BPE starts with a character-level vocabulary and iteratively merges the most frequent adjacent pair of tokens into a new token. After k merges, you have a vocabulary of characters plus learned subword units.
The Algorithm, Step by Step
Step 1: Start with a word-segmented corpus. Represent each word as characters separated by spaces, with a special end-of-word marker.
Step 2: Count the frequency of every adjacent token pair across the entire corpus.
Step 3: Merge the most frequent pair into a new single token.
Step 4: Repeat steps 2-3 for k iterations (where k = number of merges = vocabulary size beyond characters).
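To make steps 2 and 3 concrete before the full implementation, here is a minimal worked example on a three-word table. The table `word_freqs` and its counts are made up for illustration, and words are stored as tuples of symbols rather than the space-separated strings used in the rest of this article.

```python
from collections import Counter

# Hypothetical toy table: each word is a tuple of symbols ending in </w>.
word_freqs = {
    ("l", "o", "w", "</w>"): 5,
    ("l", "o", "w", "e", "r", "</w>"): 2,
    ("n", "e", "w", "</w>"): 3,
}

# Step 2: count adjacent pairs, weighted by how often each word occurs.
pair_counts = Counter()
for symbols, freq in word_freqs.items():
    for left, right in zip(symbols, symbols[1:]):
        pair_counts[(left, right)] += freq

# Step 3: the most frequent pair is the next merge candidate.
best_pair = max(pair_counts, key=lambda p: (pair_counts[p], p))
print(best_pair, pair_counts[best_pair])  # ('w', '</w>') 8
```

Here ("w", "</w>") wins with 5 + 3 = 8 occurrences, ahead of ("l", "o") and ("o", "w") with 7 each. It also shows what the end marker buys: a word-final "w" is counted separately from the "w" inside "lower".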
from collections import Counter, defaultdict
import re


def get_vocab_from_corpus(corpus: list[str]) -> dict[str, int]:
    """
    Build initial word frequency dictionary.

    Each word is represented as space-separated characters + </w> end marker.
    Example: "low" → "l o w </w>"

    The </w> marker lets us distinguish "low" from "lower" at the character level.
    After merges, "low</w>" becomes one unit while "low e r </w>" keeps them separate.
    """
    vocab: dict[str, int] = defaultdict(int)
    for text in corpus:
        for word in text.strip().split():
            # Convert word to space-separated chars with end marker
            char_word = " ".join(list(word)) + " </w>"
            vocab[char_word] += 1
    return dict(vocab)


def get_pair_frequencies(vocab: dict[str, int]) -> Counter:
    """
    Count frequency of every adjacent pair in the vocab.

    vocab: {word_as_chars: frequency}
    Returns Counter of {(token_a, token_b): total_frequency}

    Time: O(V * max_word_len) where V = vocab size
    """
    pairs: Counter = Counter()
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[(symbols[i], symbols[i + 1])] += freq
    return pairs


def merge_vocab(vocab: dict[str, int], pair: tuple[str, str]) -> dict[str, int]:
    """
    Merge all occurrences of `pair` in every word in the vocab.

    Example: merge ('e', 's') in vocab
        "e s t </w>" becomes "es t </w>"

    Uses a regex on the space-separated representation. A plain str.replace
    is not safe here: replacing "b c" would also match inside "ab c" and
    wrongly produce "abc".
    """
    # Match the pair only when it is bounded by whitespace or the string edges
    pattern = re.compile(r"(?<!\S)" + re.escape(" ".join(pair)) + r"(?!\S)")
    replacement = "".join(pair)  # merged form (no space between)
    new_vocab: dict[str, int] = {}
    for word, freq in vocab.items():
        # A function replacement avoids backslash escapes being interpreted
        new_word = pattern.sub(lambda m: replacement, word)
        new_vocab[new_word] = freq
    return new_vocab

Full BPE Training Implementation
def train_bpe(
    corpus: list[str],
    num_merges: int,
    verbose: bool = False,
) -> tuple[dict[str, int], list[tuple[str, str]]]:
    """
    Train BPE tokenizer on a corpus.

    Parameters:
        corpus: list of training sentences
        num_merges: number of merge operations = extra vocab size beyond chars

    Returns:
        vocab: final vocabulary {word_repr: freq}
        merges: ordered list of merges performed [(token_a, token_b), ...]

    The merges list is the "model": to tokenize new text, apply
    these merges in order.

    Time: O(num_merges * V * max_word_len)
    Space: O(V * max_word_len)
    """
    vocab = get_vocab_from_corpus(corpus)
    merges: list[tuple[str, str]] = []

    if verbose:
        print(f"Initial vocab size: {len(vocab)} unique words")
        print("Initial representation sample:")
        for word in list(vocab.keys())[:3]:
            print(f"  '{word}' (freq={vocab[word]})")
        print()

    for merge_idx in range(num_merges):
        pairs = get_pair_frequencies(vocab)
        if not pairs:
            if verbose:
                print(f"No more pairs at merge {merge_idx}. Stopping.")
            break

        # Find the most frequent pair. Ties go to the lexicographically
        # larger pair so that training is deterministic.
        best_pair = max(pairs, key=lambda p: (pairs[p], p))
        best_freq = pairs[best_pair]

        if verbose:
            merged_token = "".join(best_pair)
            print(f"Merge {merge_idx + 1:3d}: {best_pair} → '{merged_token}' (freq={best_freq})")

        vocab = merge_vocab(vocab, best_pair)
        merges.append(best_pair)

    return vocab, merges


# Test on a toy corpus
toy_corpus = [
    "low lower lowest",
    "new newer newest",
    "old older oldest",
    "low new low new low",
    "the newest lowest old newer",
]

print("=== BPE Training ===")
trained_vocab, merge_list = train_bpe(toy_corpus, num_merges=15, verbose=True)

print("\nFinal vocab sample:")
for word, freq in sorted(trained_vocab.items(), key=lambda x: -x[1])[:10]:
    print(f"  '{word}': {freq}")
print(f"\nTotal merges learned: {len(merge_list)}")

Applying BPE Merges to New Text
After training, encoding new text means applying the learned merges in order.
def apply_bpe_to_word(word: str, merges: list[tuple[str, str]]) -> list[str]:
    """
    Apply trained BPE merges to a single word.

    Steps:
    1. Start with character-level representation + </w>
    2. Apply each merge in training order
    3. Return the resulting tokens

    Time: O(num_merges * word_len) per word
    """
    # Start: "lower" → ["l", "o", "w", "e", "r", "</w>"]
    symbols = list(word) + ["</w>"]

    for merge_pair in merges:
        # Try to apply this merge anywhere in the symbol list
        new_symbols = []
        i = 0
        while i < len(symbols):
            # Check if merge_pair[0] and merge_pair[1] are adjacent at position i
            if (
                i < len(symbols) - 1
                and symbols[i] == merge_pair[0]
                and symbols[i + 1] == merge_pair[1]
            ):
                new_symbols.append(merge_pair[0] + merge_pair[1])  # merge!
                i += 2  # skip both
            else:
                new_symbols.append(symbols[i])
                i += 1
        symbols = new_symbols

    return symbols


def encode_bpe(
    text: str,
    merges: list[tuple[str, str]],
    include_end_marker: bool = False,
) -> list[str]:
    """
    Encode text using trained BPE merges.

    Returns list of BPE tokens.
    Tokens ending with </w> mark end-of-word boundaries.
    """
    all_tokens = []
    for word in text.strip().split():
        word_tokens = apply_bpe_to_word(word, merges)
        if not include_end_marker:
            # Remove </w> markers for cleaner display
            word_tokens = [t.replace("</w>", "") for t in word_tokens if t != "</w>"]
        all_tokens.extend(word_tokens)
    return all_tokens


# Test encoding
test_words = ["low", "lower", "lowest", "newer", "oldest", "unknown"]
print("\n=== BPE Encoding ===")
for word in test_words:
    tokens = apply_bpe_to_word(word, merge_list)
    print(f"  '{word}' → {tokens}")

# Encode a full sentence
sentence = "the newest lowest approach"
tokens = encode_bpe(sentence, merge_list)
print(f"\nEncoding: '{sentence}'")
print(f"Tokens: {tokens}")

Building the BPE Vocabulary
def build_bpe_vocabulary(
    initial_chars: set[str],
    merges: list[tuple[str, str]],
) -> dict[str, int]:
    """
    Build the final token-to-id mapping.

    Starts with special tokens + all characters, then adds each merged token.
    The order matters: tokens learned later have higher IDs.
    """
    vocab: dict[str, int] = {}

    # Special tokens first
    special = ["[PAD]", "[UNK]", "[BOS]", "[EOS]"]
    for i, token in enumerate(special):
        vocab[token] = i

    # All individual characters
    for char in sorted(initial_chars):
        if char not in vocab:
            vocab[char] = len(vocab)

    # Each merged token
    for pair in merges:
        merged = pair[0] + pair[1]
        if merged not in vocab:
            vocab[merged] = len(vocab)

    return vocab


def get_initial_chars(corpus: list[str]) -> set[str]:
    """Get all unique characters in the corpus + end marker."""
    chars = set()
    for text in corpus:
        for word in text.split():
            chars.update(word)
    chars.add("</w>")
    return chars


initial_chars = get_initial_chars(toy_corpus)
bpe_vocab = build_bpe_vocabulary(initial_chars, merge_list)

print("\n=== BPE Vocabulary ===")
print(f"Vocabulary size: {len(bpe_vocab)}")
print("Sample entries:")
for token, token_id in list(bpe_vocab.items())[:15]:
    print(f"  '{token}': {token_id}")

Complete BPE Tokenizer Class
class BPETokenizer:
    """
    A minimal but complete BPE tokenizer.

    This implements the same core merge algorithm as GPT-2's tokenizer,
    without the byte-level encoding or regex pre-tokenization.
    Hugging Face's tokenizers library implements BPE in Rust for speed.
    """

    def __init__(self):
        self.merges: list[tuple[str, str]] = []
        self.vocab: dict[str, int] = {}
        self.id_to_token: dict[int, str] = {}
        # Set of learned pairs for O(1) membership checks
        # (kept for convenience; the methods below do not use it)
        self._merge_set: set[tuple[str, str]] = set()

    def train(self, corpus: list[str], num_merges: int = 50) -> "BPETokenizer":
        """Train BPE on corpus."""
        _, self.merges = train_bpe(corpus, num_merges, verbose=False)
        self._merge_set = set(self.merges)
        initial_chars = get_initial_chars(corpus)
        self.vocab = build_bpe_vocabulary(initial_chars, self.merges)
        self.id_to_token = {v: k for k, v in self.vocab.items()}
        return self

    def tokenize(self, text: str) -> list[str]:
        """Convert text to list of BPE token strings."""
        tokens = []
        for word in text.strip().split():
            word_tokens = apply_bpe_to_word(word, self.merges)
            tokens.extend(word_tokens)
        return tokens

    def encode(self, text: str) -> list[int]:
        """Convert text to list of token IDs."""
        unk_id = self.vocab.get("[UNK]", 1)
        tokens = self.tokenize(text)
        return [self.vocab.get(t, unk_id) for t in tokens]

    def decode(self, ids: list[int]) -> str:
        """Convert token IDs back to text (whitespace is normalized to single spaces)."""
        tokens = [self.id_to_token.get(i, "[UNK]") for i in ids]
        # Subword tokens concatenate directly; each </w> marks where a space belongs.
        # Joining with " " instead would split "newer" into "new er".
        text = "".join(tokens)
        return text.replace("</w>", " ").strip()

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)

    def __repr__(self) -> str:
        return f"BPETokenizer(vocab_size={self.vocab_size}, num_merges={len(self.merges)})"


# Full end-to-end test
tokenizer = BPETokenizer()
tokenizer.train(toy_corpus, num_merges=20)
print(f"\n{tokenizer}")

test_texts = [
    "low newer oldest",
    "the lowest new old",
]
for text in test_texts:
    ids = tokenizer.encode(text)
    decoded = tokenizer.decode(ids)
    print(f"\nText: '{text}'")
    print(f"  Tokens: {tokenizer.tokenize(text)}")
    print(f"  IDs: {ids}")
    print(f"  Decoded: '{decoded}'")

Why BPE Handles Unknown Words
# The key insight: BPE falls back to smaller units instead of one [UNK] per word.
# "zephyr" was never in training, but it can still be split into characters.
def demonstrate_oov_handling(tokenizer: BPETokenizer) -> None:
    """
    Show how BPE handles out-of-vocabulary words by character fallback.

    Word-level: "zephyr" → [UNK]  (total information loss)
    BPE:        "zephyr" → ["z", "e", "p", "h", "y", "r", "</w>"]
                (when no learned merge applies)
                → can still use character-level patterns
    """
    oov_words = ["zephyr", "transformers", "gpt4", "cryptocurrency"]
    print("\n=== OOV Handling ===")
    for word in oov_words:
        tokens = apply_bpe_to_word(word, tokenizer.merges)
        print(f"  '{word}' → {tokens}")
    # Characters that were in the training vocab can be encoded to IDs.
    # Characters that weren't (e.g. "z" for the toy corpus) become [UNK] in encode().


demonstrate_oov_handling(tokenizer)

Interview Summary
The BPE algorithm in 4 lines:
- Start with characters as tokens
- Count all adjacent token pairs in the corpus
- Merge the most frequent pair into a new token
- Repeat k times
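As a compact restatement of the four steps, here is a sketch of the whole training loop in one short function. The names `merge_pair` and `compact_bpe` are illustrative, not part of the code above. It assumes words are stored as tuples of symbols, which sidesteps substring-matching pitfalls on space-separated strings, and it breaks ties toward the lexicographically larger pair, as `train_bpe` does.

```python
from collections import Counter


def merge_pair(symbols, pair):
    """Replace each left-to-right, non-overlapping occurrence of `pair`."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return tuple(out)


def compact_bpe(words, k):
    # Step 1: characters (plus an end-of-word marker) are the starting tokens.
    freqs = Counter(tuple(w) + ("</w>",) for w in words)
    merges = []
    for _ in range(k):  # Step 4: repeat up to k times
        # Step 2: count adjacent pairs, weighted by word frequency.
        pairs = Counter()
        for symbols, f in freqs.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += f
        if not pairs:
            break
        # Step 3: merge the most frequent pair in every word.
        best = max(pairs, key=lambda p: (pairs[p], p))
        merged = Counter()
        for symbols, f in freqs.items():
            merged[merge_pair(symbols, best)] += f
        freqs = merged
        merges.append(best)
    return merges


print(compact_bpe(["low", "low", "lower", "lowest"], 3))
# [('o', 'w'), ('l', 'ow'), ('low', 'e')]
```

The first merge is ("o", "w") rather than ("l", "o") only because of the tie-break: both pairs occur 4 times. A different tie-break rule would give a different, equally valid merge list.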
Why BPE beats word-level:
- Handles rare words through subword decomposition
- Vocabulary is bounded (you choose k merges)
- Better cross-lingual transfer — common subwords appear in multiple languages
Why GPT uses byte-level BPE: GPT-2 runs BPE on UTF-8 bytes rather than Unicode characters. Every string is a sequence of bytes, so the 256 byte values form the initial vocabulary and no input is ever truly OOV. Merges are learned on top of those bytes.
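A small sketch of the byte-level starting point, using only Python's built-in UTF-8 encoding. The helper name `byte_symbols` is made up for this example. GPT-2's actual tokenizer also remaps bytes to printable characters and pre-splits text with a regex, and neither step is shown here.

```python
def byte_symbols(text: str) -> list[int]:
    """Return the UTF-8 byte values of `text`; every value is in range(256)."""
    return list(text.encode("utf-8"))


for sample in ["low", "naïve", "日本"]:
    ids = byte_symbols(sample)
    # The round trip shows that no information is lost at the byte level.
    print(sample, ids, bytes(ids).decode("utf-8") == sample)
```

ASCII text yields one byte per character, while "ï" takes two bytes and each of the Japanese characters takes three. Non-English text therefore starts from longer symbol sequences and relies more heavily on learned merges.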
Time complexity of encoding: O(m × w) where m = number of merges and w = max word length. In practice, words converge quickly and most merges after the first few passes through a word have no effect.
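One way to avoid scanning every merge for every word is to look up merge ranks and repeatedly apply the earliest-learned pair that is actually present. The sketch below uses a hypothetical `encode_by_rank` function and a hand-written `example_merges` list. It is similar in spirit to rank-based encoders such as GPT-2's, but it is not a copy of any library's implementation.

```python
def encode_by_rank(word, merges):
    """Merge the lowest-ranked (earliest-learned) pair present until none apply."""
    rank = {pair: i for i, pair in enumerate(merges)}
    symbols = list(word) + ["</w>"]
    while len(symbols) > 1:
        candidates = [p for p in zip(symbols, symbols[1:]) if p in rank]
        if not candidates:
            break
        best = min(candidates, key=rank.get)
        merged, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols


# Hand-written merge list, for illustration only
example_merges = [("l", "o"), ("lo", "w"), ("e", "r"), ("er", "</w>")]
print(encode_by_rank("lower", example_merges))  # ['low', 'er</w>']
```

Each pass removes at least one symbol, so the loop runs at most w times per word regardless of how many merges were learned. For merge lists where every merged string comes from a single rule, this should give the same segmentation as applying the merges in training order.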