Layer Normalization

The problem

Implement layer normalization (forward and backward) in NumPy. Given an input x of shape (…, D), normalize across the last (feature) axis, then apply a learned affine transform γ ⊙ x̂ + β. Return the output, and in the backward pass return gradients with respect to x, γ, and β.

Pattern: per-token statistics

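Layer norm computes the mean and variance independently for each example (or each token, in the transformer setting) across the feature dimension. The normalized x̂ has zero mean and unit variance per token (up to the small ε), and γ and β then let the model rescale and shift it. Keeping per-token activation scales under control this way tends to stabilize training and tolerate higher learning rates. Because no statistic is shared across examples, the output also does not depend on batch composition.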
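A minimal sketch of per-token statistics on a toy (B, S, D) tensor, without γ and β. The helper `normalize_last_axis` and the random input are illustrative assumptions, not part of the original solution. The last check makes the batch-independence claim concrete: a token normalized on its own matches the same token normalized inside the batch.

```python
import numpy as np


def normalize_last_axis(x, eps=1e-5):
    """Hypothetical helper: per-token normalization over the last axis (no affine)."""
    mu = x.mean(axis=-1, keepdims=True)   # one mean per token
    var = x.var(axis=-1, keepdims=True)   # one biased (ddof=0) variance per token
    return (x - mu) / np.sqrt(var + eps), mu


# Toy input, assumed for illustration: batch B=2, sequence S=3, features D=4.
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 3, 4))

x_hat, mu = normalize_last_axis(x)
print(mu.shape)               # (2, 3, 1), i.e. (B, S, 1)
print(x_hat.mean(axis=-1))    # approximately 0 for every token
print(x_hat.var(axis=-1))     # approximately 1 for every token (eps pulls it slightly below 1)

# Independence from the batch: normalizing one token by itself
# gives the same result as normalizing it inside the batch.
alone, _ = normalize_last_axis(x[1, 2])
print(np.allclose(alone, x_hat[1, 2]))  # True
```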

Forward

import numpy as np

class LayerNorm:
    """Layer normalization over the last axis (the feature axis)."""
    def __init__(self, d: int, eps: float = 1e-5):
        self.gamma = np.ones(d)
        self.beta  = np.zeros(d)
        self.eps   = eps
        self.cache = None

    def forward(self, x: np.ndarray) -> np.ndarray:
        # x has shape (..., d). Normalize across the last axis.
        mu  = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        std = np.sqrt(var + self.eps)
        x_hat = (x - mu) / std
        out = self.gamma * x_hat + self.beta
        self.cache = (x_hat, std)
        return out

Backward

The chain rule through layer norm produces three terms because μ and σ both depend on every input feature: a direct term through x̂, a term through μ, and a term through σ. The closed form is:

dx = (1 / (D · σ)) · [D · dx̂ − Σ dx̂ − x̂ · Σ (dx̂ · x̂)]

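Derivation: write x̂ = (x − μ)/σ, where σ = √(Var(x) + ε) uses the biased (1/D) variance, and let dx̂ = dout ⊙ γ be the gradient arriving at x̂. Since ∂μ/∂xⱼ = 1/D and ∂σ/∂xⱼ = x̂ⱼ/D, the Jacobian is ∂x̂/∂x = (1/σ) · (I − (1/D) · 11ᵀ − x̂x̂ᵀ/D). It is symmetric, so multiplying it by the upstream gradient dx̂ gives the three-term form above. Both Σs run over the feature axis.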
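A quick numerical sanity check of the derivation, assuming a single feature vector and a random vector standing in for dx̂: build the Jacobian (1/σ)(I − 11ᵀ/D − x̂x̂ᵀ/D) explicitly, compare it with finite differences, and confirm that Jᵀ dx̂ matches the three-term form. The helper `x_hat_of` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
D, eps = 5, 1e-5
x = rng.normal(size=D)
dxhat = rng.normal(size=D)  # stands in for dout * gamma at a single token


def x_hat_of(v):
    """Normalize one feature vector (biased variance, eps inside the sqrt)."""
    return (v - v.mean()) / np.sqrt(v.var() + eps)


x_hat = x_hat_of(x)
sigma = np.sqrt(x.var() + eps)

# Explicit Jacobian J[i, j] = d x_hat_i / d x_j from the derivation above.
J = (np.eye(D) - np.ones((D, D)) / D - np.outer(x_hat, x_hat) / D) / sigma

# Central-difference Jacobian, one column per perturbed input feature.
h = 1e-6
J_num = np.empty((D, D))
for j in range(D):
    e = np.zeros(D)
    e[j] = h
    J_num[:, j] = (x_hat_of(x + e) - x_hat_of(x - e)) / (2 * h)

# dx = J^T dxhat (J is symmetric), compared with the three-term closed form.
dx_jacobian = J.T @ dxhat
dx_closed = (D * dxhat - dxhat.sum() - x_hat * (dxhat * x_hat).sum()) / (D * sigma)

print(np.allclose(J, J_num, atol=1e-6))     # True
print(np.allclose(dx_jacobian, dx_closed))  # True
```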

    def backward(self, dout: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return grads w.r.t. x, gamma, beta."""
        x_hat, std = self.cache
        d = x_hat.shape[-1]

        # Param grads (sum over all axes except the feature axis)
        sum_axes = tuple(range(dout.ndim - 1))
        dgamma = (dout * x_hat).sum(axis=sum_axes)
        dbeta  = dout.sum(axis=sum_axes)

        # Backprop through normalization
        dxhat = dout * self.gamma
        # dx = (1 / (d * std)) * (d * dxhat - sum(dxhat) - x_hat * sum(dxhat * x_hat))
        s1 = dxhat.sum(axis=-1, keepdims=True)
        s2 = (dxhat * x_hat).sum(axis=-1, keepdims=True)
        dx = (1.0 / (d * std)) * (d * dxhat - s1 - x_hat * s2)
        return dx, dgamma, dbeta
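A central-difference gradient check is a standard way to confirm a hand-written backward pass. The sketch below assumes a scalar loss L = Σ w ⊙ out with a random w (so dout = w), repeats a compact copy of the class so it runs standalone, and uses a hypothetical `numeric_grad` helper. In float64 the relative errors should be tiny.

```python
import numpy as np


class LayerNorm:
    """Compact copy of the forward/backward above so this check runs standalone."""

    def __init__(self, d, eps=1e-5):
        self.gamma = np.ones(d)
        self.beta = np.zeros(d)
        self.eps = eps
        self.cache = None

    def forward(self, x):
        mu = x.mean(axis=-1, keepdims=True)
        std = np.sqrt(x.var(axis=-1, keepdims=True) + self.eps)
        x_hat = (x - mu) / std
        self.cache = (x_hat, std)
        return self.gamma * x_hat + self.beta

    def backward(self, dout):
        x_hat, std = self.cache
        d = x_hat.shape[-1]
        sum_axes = tuple(range(dout.ndim - 1))
        dgamma = (dout * x_hat).sum(axis=sum_axes)
        dbeta = dout.sum(axis=sum_axes)
        dxhat = dout * self.gamma
        s1 = dxhat.sum(axis=-1, keepdims=True)
        s2 = (dxhat * x_hat).sum(axis=-1, keepdims=True)
        dx = (d * dxhat - s1 - x_hat * s2) / (d * std)
        return dx, dgamma, dbeta


def numeric_grad(f, a, h=1e-6):
    """Central-difference gradient of the scalar f() w.r.t. every entry of a.

    a is perturbed in place and restored, so f must read it by reference.
    """
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + h
        f_plus = f()
        a[idx] = old - h
        f_minus = f()
        a[idx] = old
        g[idx] = (f_plus - f_minus) / (2 * h)
    return g


rng = np.random.default_rng(2)
B, S, D = 2, 3, 4
x = rng.normal(size=(B, S, D))
w = rng.normal(size=(B, S, D))  # loss = sum(w * out), so dL/dout = w

ln = LayerNorm(D)
ln.gamma = rng.normal(size=D)  # non-trivial gamma/beta so their grads are exercised
ln.beta = rng.normal(size=D)


def loss():
    return float((w * ln.forward(x)).sum())


# Analytic gradients first: every loss() call reruns forward and overwrites ln.cache.
ln.forward(x)
dx, dgamma, dbeta = ln.backward(w)

errors = {}
for name, analytic, param in [("x", dx, x), ("gamma", dgamma, ln.gamma), ("beta", dbeta, ln.beta)]:
    numeric = numeric_grad(loss, param)
    errors[name] = np.abs(analytic - numeric).max() / max(1e-12, np.abs(numeric).max())
    print(name, errors[name])  # expect values far below 1e-6
```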

Trace (forward only)

x   = [2.0, 4.0, 6.0, 8.0]   (single feature vector, d=4)
mu  = 5.0
var = 5.0
std = sqrt(5 + 1e-5) ≈ 2.236
x_hat = [(2-5)/2.236, (4-5)/2.236, (6-5)/2.236, (8-5)/2.236]
      ≈ [-1.342, -0.447, 0.447, 1.342]

With gamma=[1,1,1,1], beta=[0,0,0,0]:
out ≈ x_hat ≈ [-1.342, -0.447, 0.447, 1.342]

mean(out) ≈ 0,  var(out) ≈ 1   ✓
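The same numbers can be reproduced in NumPy as a check on the hand trace, assuming eps = 1e-5, gamma of ones and beta of zeros as above:

```python
import numpy as np

x = np.array([2.0, 4.0, 6.0, 8.0])
eps = 1e-5
mu = x.mean()               # 5.0
var = x.var()               # 5.0 (biased, ddof=0)
std = np.sqrt(var + eps)    # ≈ 2.236
x_hat = (x - mu) / std
gamma, beta = np.ones(4), np.zeros(4)
out = gamma * x_hat + beta

print(np.round(x_hat, 3))   # [-1.342 -0.447  0.447  1.342]
print(out.mean(), out.var())  # ≈ 0 and ≈ 1 (var is var/(var+eps), just under 1)
```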

Layer norm vs batch norm

# Suppose x has shape (B, S, D) — batch, sequence, features.

# Batch norm normalizes across the BATCH axis (and seq) per feature:
#   mu, var have shape (1, 1, D); statistics shared across all examples.
#   - Couples examples together.  Different at train vs eval (running stats).
#   - Noisy stats at small batch sizes (degenerate at batch=1 with no seq axis);
#     padding in variable-length seqs contaminates the stats.

# Layer norm normalizes across the FEATURE axis per (example, position):
#   mu, var have shape (B, S, 1); independent per token.
#   - No coupling.  Same behavior at train and eval.
#   - Trivially works for any batch size; standard in transformers.

This is, in a nutshell, why transformers use layer norm (or close variants of it) rather than batch norm: independence across examples, no train/eval skew, and no batch-size sensitivity. Batch norm remains the default for CNNs on images, where the regularizing effect of batch statistics is part of the appeal.
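A small NumPy sketch of the two reductions on a (B, S, D) tensor. The helper names are hypothetical, and the batch-norm side is training-mode only (no running statistics, no affine transform), so it illustrates the axes and the coupling rather than a full batch-norm layer.

```python
import numpy as np


def batch_norm_stats(x):
    # Batch-norm style (training mode): one mean/var per feature, pooled over batch and seq.
    return x.mean(axis=(0, 1), keepdims=True), x.var(axis=(0, 1), keepdims=True)


def layer_norm_stats(x):
    # Layer-norm style: one mean/var per (example, position), over the feature axis.
    return x.mean(axis=-1, keepdims=True), x.var(axis=-1, keepdims=True)


def normalize(x, stats_fn, eps=1e-5):
    mu, var = stats_fn(x)
    return (x - mu) / np.sqrt(var + eps)


rng = np.random.default_rng(3)
B, S, D = 4, 5, 8
x = rng.normal(size=(B, S, D))

print(batch_norm_stats(x)[0].shape)  # (1, 1, 8)
print(layer_norm_stats(x)[0].shape)  # (4, 5, 1)

# Perturb only example 0, then check whether example 1's output changes.
x2 = x.copy()
x2[0] += 10.0
bn_changed = not np.allclose(normalize(x, batch_norm_stats)[1], normalize(x2, batch_norm_stats)[1])
ln_changed = not np.allclose(normalize(x, layer_norm_stats)[1], normalize(x2, layer_norm_stats)[1])
print(bn_changed, ln_changed)        # True False
```

Shifting example 0 moves the batch-norm output of example 1 but leaves its layer-norm output untouched, which is the coupling difference described in the comparison.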

Complexity

Variations worth knowing