Multi-Head Attention
Interview Prep
The problem
Implement multi-head self-attention from scratch. Given an input
x of shape (B, L, d_model) and a number of
attention heads, return the attended output of the same shape. The
interviewer will follow up with: why multiple heads? what does
each head "see"? why do we split, attend, then merge instead of
just running attention once with full-dimensional Q/K/V?
Building block: scaled dot-product attention
import numpy as np

def softmax(x, axis):
    m = x.max(axis=axis, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Q, K, V are (..., L, d). Returns (..., L, d) and the attention weights."""
    d = Q.shape[-1]
    scores = Q @ K.swapaxes(-2, -1) / np.sqrt(d)  # (..., L, L)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)
    weights = softmax(scores, axis=-1)
    out = weights @ V
    return out, weights
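The scaling by √d is essential. If the entries of Q and K are
independent with zero mean and unit variance, each dot product has
variance d, so its typical magnitude grows like √d. Large scores
push softmax into the saturating regime where one token's attention
weight dominates and gradients vanish. Dividing by √d keeps the
variance of each entry of QK^T at roughly 1, independent of d.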
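The variance argument can be checked empirically. The sketch below assumes q and k have independent standard-normal entries (an idealisation; trained projections need not satisfy it) and compares the variance of raw and scaled dot products for a few values of d. The helper name dot_product_variance is made up for illustration.

```python
import numpy as np

def dot_product_variance(d, n_samples=20000, seed=0):
    """Empirical variance of q.k and q.k / sqrt(d).

    Assumption: entries of q and k are i.i.d. standard normal.
    """
    rng = np.random.default_rng(seed)
    q = rng.standard_normal((n_samples, d))
    k = rng.standard_normal((n_samples, d))
    raw = (q * k).sum(axis=-1)      # one dot product per sample
    scaled = raw / np.sqrt(d)
    return raw.var(), scaled.var()

for d in (4, 64, 256):
    raw_var, scaled_var = dot_product_variance(d)
    print(f"d={d:4d}  var(q.k)={raw_var:8.2f}  var(q.k/sqrt(d))={scaled_var:.3f}")
```

The raw variance should track d, while the scaled variance should stay close to 1 for every d.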
Pattern: split / attend / merge
Multi-head attention runs H parallel attention
operations on lower-dimensional projections of the same input. The
total compute is roughly the same as one attention head at full dimension:
we partition the channels across heads instead of
spending them all on one similarity score. The empirical benefit
is that different heads end up specializing on different relational
patterns (positional, syntactic, coreference, etc.).
Mechanically: project to Q/K/V of dimension d_model,
reshape into (B, H, L, d_head) where
d_head = d_model / H, run attention per
(batch, head), then concat the heads back to
(B, L, d_model) and apply a final output projection.
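The reshape step is where axis-ordering bugs tend to hide, so it can help to look at it on a tiny tensor. This sketch uses standalone helpers (split_heads and merge_heads are illustrative names, not from any library) to check that head h receives the contiguous channel slice starting at h * d_head, and that merging inverts the split.

```python
import numpy as np

def split_heads(x, n_heads):
    """(B, L, d_model) -> (B, n_heads, L, d_head)."""
    B, L, d_model = x.shape
    d_head = d_model // n_heads
    return x.reshape(B, L, n_heads, d_head).transpose(0, 2, 1, 3)

def merge_heads(x):
    """(B, n_heads, L, d_head) -> (B, L, n_heads * d_head)."""
    B, H, L, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, L, H * d_head)

B, L, d_model, n_heads = 2, 3, 8, 4
d_head = d_model // n_heads
x = np.arange(B * L * d_model, dtype=float).reshape(B, L, d_model)

heads = split_heads(x, n_heads)
print(heads.shape)  # (2, 4, 3, 2)

# Head h holds channels h*d_head .. (h+1)*d_head - 1 of every token.
for h in range(n_heads):
    assert np.array_equal(heads[:, h], x[:, :, h * d_head:(h + 1) * d_head])

# Merging undoes the split exactly.
assert np.array_equal(merge_heads(heads), x)
```

A common mistake is to reshape straight to (B, n_heads, L, d_head) without the transpose. The shape comes out right, but tokens and channels are silently mixed.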
Solution
class MultiHeadAttention:
    def __init__(self, d_model: int, n_heads: int, seed: int = 0):
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(d_model)
        self.W_q = rng.standard_normal((d_model, d_model)) * scale
        self.W_k = rng.standard_normal((d_model, d_model)) * scale
        self.W_v = rng.standard_normal((d_model, d_model)) * scale
        self.W_o = rng.standard_normal((d_model, d_model)) * scale

    def _split(self, x):
        # (B, L, d_model) -> (B, n_heads, L, d_head)
        B, L, _ = x.shape
        x = x.reshape(B, L, self.n_heads, self.d_head)
        return x.transpose(0, 2, 1, 3)

    def _merge(self, x):
        # (B, n_heads, L, d_head) -> (B, L, d_model)
        B, H, L, d = x.shape
        return x.transpose(0, 2, 1, 3).reshape(B, L, H * d)

    def forward(self, x, mask=None):
        """x: (B, L, d_model). Self-attention."""
        Q = self._split(x @ self.W_q)
        K = self._split(x @ self.W_k)
        V = self._split(x @ self.W_v)
        out, _ = scaled_dot_product_attention(Q, K, V, mask)  # (B, H, L, d_head)
        out = self._merge(out)                                # (B, L, d_model)
        return out @ self.W_o
Shapes trace
Shapes for B=2, L=5 (sequence length), d_model=16, n_heads=4:
x         : (2, 5, 16)
x @ W_q   : (2, 5, 16)
split     : (2, 4, 5, 4)   # heads moved up; d_head = 4
Q @ K^T   : (2, 4, 5, 5)   # one attention matrix per (batch, head)
/ sqrt(4) : same
softmax   : same, rows sum to 1
@ V       : (2, 4, 5, 4)
merge     : (2, 5, 16)
@ W_o     : (2, 5, 16)
Causal masking for autoregressive decoding
In a decoder you want each position to attend only to positions ≤
itself. The mask is a lower-triangular boolean of shape
(L, L); False entries get a large negative pre-softmax score
(-1e9 in the code above, standing in for −∞) so they receive
effectively zero weight.
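A minimal sketch of the effect on a single attention matrix, assuming the same -1e9 fill value used in scaled_dot_product_attention above. The random scores are a stand-in for real Q·K^T values, and the mask is built inline with np.tril.

```python
import numpy as np

L = 4
rng = np.random.default_rng(0)
scores = rng.standard_normal((L, L))            # stand-in for Q @ K^T / sqrt(d)

allowed = np.tril(np.ones((L, L), dtype=bool))  # True on and below the diagonal
masked = np.where(allowed, scores, -1e9)        # blocked entries get a huge negative score

# Numerically stable softmax over the key axis.
e = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)

print(np.round(weights, 3))
```

Everything above the diagonal comes out as zero, each row still sums to 1, and the first row puts all its weight on position 0 because that is the only position it may see.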
def causal_mask(L: int) -> np.ndarray:
    """True where allowed, False where blocked."""
    m = np.tril(np.ones((L, L), dtype=bool))
    return m  # broadcastable to (1, 1, L, L)
Complexity
- Time: O(B · H · L² · d_head) = O(B · L² · d_model) for the attention itself, plus O(B · L · d_model²) for the four projections. The L² term is the well-known scaling limit of transformers.
- Space: O(B · H · L²) for the attention matrices.
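To make the L² space term concrete, here is a back-of-the-envelope calculation. The sizes are hypothetical (chosen for illustration, not taken from any particular model), and it assumes the full attention weights are materialised in float32.

```python
def attention_matrix_bytes(B, H, L, bytes_per_elem=4):
    """Bytes needed to store B * H attention matrices of shape (L, L)."""
    return B * H * L * L * bytes_per_elem

GIB = 1024 ** 3
for L in (1024, 8192, 32768):
    size = attention_matrix_bytes(B=1, H=32, L=L)
    print(f"L={L:6d}: {size / GIB:8.3f} GiB")
# L=  1024:    0.125 GiB
# L=  8192:    8.000 GiB
# L= 32768:  128.000 GiB
```

Doubling L quadruples the footprint, which is why long-context implementations avoid storing the full matrix.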
Variations worth knowing
- Cross-attention: queries come from one sequence (e.g., decoder hidden states), keys/values from another (encoder outputs). Same code, different inputs to Q versus K, V.
- Grouped-query attention (GQA): share K and V projections across a group of query heads. Reduces KV-cache memory at inference; ubiquitous in modern LLMs (e.g., LLaMA-2 70B, Mistral 7B).
- FlashAttention: the same math but with a tiled, I/O-aware implementation that avoids materialising the L² matrix in HBM. Memory linear in L, big speedup, exact.
- Linear / kernelised attention: approximate softmax(QK^T)V using a feature map so the cost drops to O(L · d²). Trades accuracy for the ability to scale to very long sequences.
- Rotary positional embeddings (RoPE): inject position by rotating Q and K in pairs of dimensions before the dot product. Replaces additive positional embeddings; standard in modern LLMs.
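As one concrete instance of these variations, grouped-query attention can be sketched as follows. This is a minimal illustration, not a reference implementation: the function name, the argument layout, and the use of np.repeat to expand the shared K/V heads are choices made here for clarity (real implementations typically avoid the copy).

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(Q, K, V):
    """Q: (B, H_q, L, d); K, V: (B, H_kv, L, d), with H_q divisible by H_kv.

    Each consecutive group of H_q // H_kv query heads shares one K/V head.
    """
    H_q, H_kv = Q.shape[1], K.shape[1]
    assert H_q % H_kv == 0
    group = H_q // H_kv
    # Expand K/V so that query head h reads K/V head h // group.
    K = np.repeat(K, group, axis=1)   # (B, H_q, L, d)
    V = np.repeat(V, group, axis=1)
    scores = Q @ K.swapaxes(-2, -1) / np.sqrt(Q.shape[-1])
    return softmax(scores) @ V        # (B, H_q, L, d)

rng = np.random.default_rng(0)
B, L, d, H_q, H_kv = 2, 5, 4, 8, 2
Q = rng.standard_normal((B, H_q, L, d))
K = rng.standard_normal((B, H_kv, L, d))
V = rng.standard_normal((B, H_kv, L, d))

out = grouped_query_attention(Q, K, V)
print(out.shape)  # (2, 8, 5, 4)
```

With these sizes the K/V cache holds 2 heads instead of 8, a 4x reduction. Setting H_kv = H_q recovers standard multi-head attention, and H_kv = 1 gives multi-query attention.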