Multi-Head Attention


Difficulty: Hard. Category: ML Engineering. Tags: ml, transformer, numpy.

The problem

Implement multi-head self-attention from scratch. Given an input x of shape (B, L, d_model) and a number of attention heads, return the attended output of the same shape. The interviewer will follow up with: why multiple heads? what does each head "see"? why do we split, attend, then merge instead of just running attention once with full-dimensional Q/K/V?

Building block: scaled dot-product attention

import numpy as np

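def softmax(x, axis):
    # Subtract the max along `axis` before exponentiating for numerical stability;
    # the shift cancels in the ratio, so the result is unchanged.
    m = x.max(axis=axis, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=axis, keepdims=True)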

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Q, K, V are (..., L, d). Returns (..., L, d) and the attention weights.

    mask: optional boolean array broadcastable to (..., L, L);
    True = may attend, False = blocked.
    """
    d = Q.shape[-1]
    scores = Q @ K.swapaxes(-2, -1) / np.sqrt(d)        # (..., L, L)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)            # blocked scores -> effectively -inf
    weights = softmax(scores, axis=-1)                   # each row sums to 1
    out = weights @ V                                    # (..., L, d)
    return out, weights

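The 1/√d scale is essential. If the components of a query q and a key k are independent with zero mean and unit variance, q·k is a sum of d such products and has variance d. Without the scale, scores grow with d and push softmax into its saturating regime, where one token's weight approaches 1 and the gradients through the softmax vanish. Dividing by √d keeps the variance of QK^T/√d at 1 regardless of d.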
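A quick numerical check of this argument, as a sketch: it draws queries and keys with independent standard-normal components (an idealizing assumption; learned projections are not i.i.d.), then compares the score variance and how peaked the softmax is with and without the 1/√d scale. The sizes d, L and n are arbitrary choices for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    m = x.max(axis=axis, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
d, L, n = 256, 16, 1000   # head dim, keys per query, number of queries

# Assumption: i.i.d. zero-mean, unit-variance query/key components.
q = rng.standard_normal((n, d))
k = rng.standard_normal((n, L, d))

raw = np.einsum("nd,nld->nl", q, k)   # unscaled scores q.k, shape (n, L)
scaled = raw / np.sqrt(d)             # scaled scores q.k / sqrt(d)

print(raw.var(), scaled.var())        # roughly 256 (= d) and roughly 1

# Mean of the largest softmax weight per query: near 1 means one key dominates.
sat_raw = softmax(raw).max(axis=-1).mean()
sat_scaled = softmax(scaled).max(axis=-1).mean()
print(sat_raw, sat_scaled)            # sat_raw is much larger than sat_scaled
```

With d = 256 the unscaled scores have a standard deviation around 16, so the top key usually takes most of the weight; after scaling, the weights stay spread across keys.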

Pattern: split / attend / merge

Multi-head attention runs H attention operations in parallel, each on a lower-dimensional projection of the same input. The cost is essentially the same as a single head at full dimension: the projections are the same size, and computing H score matrices at width d_head = d_model / H costs H · L² · d_head = L² · d_model multiply-adds, the same as one full-width score matrix. We just partition the channels across heads instead of spending them all on one similarity score. Empirically, different heads often specialize in different relational patterns (positional, syntactic, coreference, and so on).
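To make "partition the channels" concrete, here is a small sketch (sizes arbitrary, and the 1/√d scale omitted because it differs between the two cases): a single full-width score matrix QK^T is exactly the sum of the H per-head score matrices. A single head therefore has to softmax that sum into one attention pattern, while multi-head attention softmaxes each term separately and gets H patterns for the same multiply-add count.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d_model, H = 6, 16, 4
d_head = d_model // H

Q = rng.standard_normal((L, d_model))
K = rng.standard_normal((L, d_model))

# One full-width head: a single (L, L) score matrix (1/sqrt(d) scale omitted).
full_scores = Q @ K.T

# H heads: head h uses only channels [h*d_head, (h+1)*d_head).
per_head = np.stack([
    Q[:, h * d_head:(h + 1) * d_head] @ K[:, h * d_head:(h + 1) * d_head].T
    for h in range(H)
])                                                     # (H, L, L)

# The full-width scores are exactly the sum of the per-head scores ...
print(np.allclose(per_head.sum(axis=0), full_scores))  # True

# ... and both cost the same number of multiply-adds.
print(H * L * L * d_head == L * L * d_model)           # True
```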

Mechanically: project x to Q, K and V, each of dimension d_model; reshape each into (B, H, L, d_head), where d_head = d_model / H; run attention independently per (batch, head); then concatenate the heads back to (B, L, d_model) and apply a final output projection.
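The split and merge steps are pure reshapes and transposes. A sketch with standalone split/merge functions (hypothetical names mirroring the class's _split/_merge in the solution) and small integer inputs, so the channel bookkeeping is easy to read: head h owns channels h·d_head through (h+1)·d_head − 1, merge exactly undoes split, and reshaping straight to (B, H, L, d_head) without the transpose is a common bug that scrambles positions.

```python
import numpy as np

def split(x, n_heads):
    # (B, L, d_model) -> (B, n_heads, L, d_head)
    B, L, d_model = x.shape
    return x.reshape(B, L, n_heads, d_model // n_heads).transpose(0, 2, 1, 3)

def merge(x):
    # (B, n_heads, L, d_head) -> (B, L, d_model)
    B, H, L, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, L, H * d_head)

x = np.arange(2 * 3 * 8, dtype=float).reshape(2, 3, 8)   # B=2, L=3, d_model=8
heads = split(x, n_heads=4)                               # d_head = 2

print(heads.shape)                                        # (2, 4, 3, 2)
print(np.array_equal(heads[:, 1], x[..., 2:4]))           # True: head 1 owns channels 2..3
print(np.array_equal(merge(heads), x))                    # True: merge undoes split
print(np.array_equal(x.reshape(2, 4, 3, 2), heads))       # False: no transpose = wrong layout
```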

Solution

class MultiHeadAttention:
    def __init__(self, d_model: int, n_heads: int, seed: int = 0):
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_head  = d_model // n_heads
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(d_model)   # keeps x @ W on roughly the same scale as x
        self.W_q = rng.standard_normal((d_model, d_model)) * scale
        self.W_k = rng.standard_normal((d_model, d_model)) * scale
        self.W_v = rng.standard_normal((d_model, d_model)) * scale
        self.W_o = rng.standard_normal((d_model, d_model)) * scale

    def _split(self, x):
        # (B, L, d_model) -> (B, n_heads, L, d_head)
        B, L, _ = x.shape
        x = x.reshape(B, L, self.n_heads, self.d_head)
        return x.transpose(0, 2, 1, 3)

    def _merge(self, x):
        # (B, n_heads, L, d_head) -> (B, L, d_model)
        B, H, L, d = x.shape
        return x.transpose(0, 2, 1, 3).reshape(B, L, H * d)

    def forward(self, x, mask=None):
        """Self-attention. x: (B, L, d_model); mask: optional bool array
        broadcastable to (B, n_heads, L, L), True = may attend."""
        Q = self._split(x @ self.W_q)
        K = self._split(x @ self.W_k)
        V = self._split(x @ self.W_v)
        out, _ = scaled_dot_product_attention(Q, K, V, mask)   # (B, H, L, d_head)
        out = self._merge(out)                                  # (B, L, d_model)
        return out @ self.W_o
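A useful sanity check, and a common interview follow-up, is to compare the vectorized forward pass against a naive loop over heads. The sketch below restates both as standalone functions (mha_vectorized and mha_loop are hypothetical names, not part of the class above) so it runs on its own; the loop version slices out head h's columns of W_q, W_k and W_v directly, which is exactly what the reshape in _split selects.

```python
import numpy as np

def softmax(x, axis=-1):
    m = x.max(axis=axis, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=axis, keepdims=True)

def mha_vectorized(x, W_q, W_k, W_v, W_o, n_heads, mask=None):
    B, L, d_model = x.shape
    d_head = d_model // n_heads
    def split(t):  # (B, L, d_model) -> (B, H, L, d_head)
        return t.reshape(B, L, n_heads, d_head).transpose(0, 2, 1, 3)
    Q, K, V = split(x @ W_q), split(x @ W_k), split(x @ W_v)
    scores = Q @ K.swapaxes(-2, -1) / np.sqrt(d_head)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)
    out = softmax(scores) @ V                               # (B, H, L, d_head)
    out = out.transpose(0, 2, 1, 3).reshape(B, L, d_model)  # merge heads
    return out @ W_o

def mha_loop(x, W_q, W_k, W_v, W_o, n_heads, mask=None):
    B, L, d_model = x.shape
    d_head = d_model // n_heads
    heads = []
    for h in range(n_heads):
        cols = slice(h * d_head, (h + 1) * d_head)           # head h's channels
        q, k, v = x @ W_q[:, cols], x @ W_k[:, cols], x @ W_v[:, cols]
        s = q @ k.swapaxes(-2, -1) / np.sqrt(d_head)         # (B, L, L)
        if mask is not None:
            s = np.where(mask, s, -1e9)
        heads.append(softmax(s) @ v)                         # (B, L, d_head)
    return np.concatenate(heads, axis=-1) @ W_o

rng = np.random.default_rng(0)
B, L, d_model, H = 2, 5, 16, 4
x = rng.standard_normal((B, L, d_model))
Ws = [rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4)]

y_vec = mha_vectorized(x, *Ws, n_heads=H)
y_loop = mha_loop(x, *Ws, n_heads=H)
print(y_vec.shape, np.allclose(y_vec, y_loop))              # (2, 5, 16) True

causal = np.tril(np.ones((L, L), dtype=bool))
y_vec_c = mha_vectorized(x, *Ws, n_heads=H, mask=causal)
y_loop_c = mha_loop(x, *Ws, n_heads=H, mask=causal)
print(np.allclose(y_vec_c, y_loop_c))                       # True
```

The loop is easier to reason about; the reshaped version does the same arithmetic in a few batched matmuls, which is why it is preferred in practice.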

Shapes trace

Shapes for B=2, L=5 (sequence length), d_model=16, n_heads=4:

x        : (2, 5, 16)
x @ W_q  : (2, 5, 16)
split    : (2, 4, 5, 4)        # heads moved up; d_head = 4
Q @ K^T  : (2, 4, 5, 5)        # one attention matrix per (batch, head)
/ sqrt(4): same
softmax  : same, rows sum to 1
@ V      : (2, 4, 5, 4)
merge    : (2, 5, 16)
@ W_o    : (2, 5, 16)
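The trace can be turned into assertions. A minimal sketch using the same sizes, with random weights scaled by 1/√d_model (an assumption mirroring the class's initialization), that checks each intermediate shape and that every softmax row sums to 1:

```python
import numpy as np

B, L, d_model, H = 2, 5, 16, 4
d_head = d_model // H
rng = np.random.default_rng(0)

x = rng.standard_normal((B, L, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4))

def split(t):  # (B, L, d_model) -> (B, H, L, d_head)
    return t.reshape(B, L, H, d_head).transpose(0, 2, 1, 3)

Q, K, V = split(x @ W_q), split(x @ W_k), split(x @ W_v)
assert Q.shape == (2, 4, 5, 4)

scores = Q @ K.swapaxes(-2, -1) / np.sqrt(d_head)
assert scores.shape == (2, 4, 5, 5)

e = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)
assert np.allclose(weights.sum(axis=-1), 1.0)

out = weights @ V
assert out.shape == (2, 4, 5, 4)

merged = out.transpose(0, 2, 1, 3).reshape(B, L, d_model)
y = merged @ W_o
assert merged.shape == (2, 5, 16) and y.shape == (2, 5, 16)
print("all shape checks passed")
```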

Causal masking for autoregressive decoding

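In a decoder, each position should attend only to positions at or before itself. The mask is a lower-triangular boolean matrix of shape (L, L); False entries have their pre-softmax scores replaced by a large negative value (-1e9 in scaled_dot_product_attention, effectively −∞), so they receive zero weight after the softmax.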

def causal_mask(L: int) -> np.ndarray:
    """True where allowed, False where blocked."""
    m = np.tril(np.ones((L, L), dtype=bool))
    return m   # (L, L); broadcasts against (B, H, L, L) scores
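A sketch that checks the mask behaves as described (causal_attention is a hypothetical helper that inlines the same -1e9 masking as scaled_dot_product_attention): weights above the diagonal are exactly zero, the first position attends only to itself, and perturbing the last token's keys and values leaves every earlier output unchanged.

```python
import numpy as np

def softmax(x, axis=-1):
    m = x.max(axis=axis, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(Q, K, V):
    L = Q.shape[-2]
    mask = np.tril(np.ones((L, L), dtype=bool))          # True = allowed
    scores = Q @ K.swapaxes(-2, -1) / np.sqrt(Q.shape[-1])
    scores = np.where(mask, scores, -1e9)                # broadcasts over (B, H)
    weights = softmax(scores, axis=-1)
    return weights @ V, weights

rng = np.random.default_rng(0)
B, H, L, d = 2, 4, 6, 8
Q, K, V = (rng.standard_normal((B, H, L, d)) for _ in range(3))

out, w = causal_attention(Q, K, V)
print(np.allclose(np.triu(w, k=1), 0.0))    # True: no weight on future positions
print(np.allclose(w[..., 0, 0], 1.0))       # True: position 0 only sees itself

# Perturb the last token's keys/values: earlier outputs must not change.
K2, V2 = K.copy(), V.copy()
K2[..., -1, :] += 10.0
V2[..., -1, :] += 10.0
out2, _ = causal_attention(Q, K2, V2)
print(np.allclose(out[..., :-1, :], out2[..., :-1, :]))   # True
```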

Complexity

Variations worth knowing