Layer Normalization
Interview Prep
The problem
Implement layer normalization (forward and backward) in NumPy.
Given an input x of shape (…, D),
normalize across the last (feature) axis, then apply a learned
affine transform γ ⊙ x̂ + β. Return the output, and
in the backward pass return gradients with respect to
x, γ, and β.
Pattern: per-token statistics
Layer norm computes the mean and variance independently for each example (each token, in the transformer setting) across the feature dimension. Before the affine transform, the normalized vector x̂ has zero mean and unit variance per token. This keeps the activation scale stable from layer to layer, which in practice tolerates larger learning rates, and it removes any dependence on batch composition.
Forward
import numpy as np

class LayerNorm:
    """Layer normalization over the last axis (the feature axis)."""

    def __init__(self, d: int, eps: float = 1e-5):
        self.gamma = np.ones(d)
        self.beta = np.zeros(d)
        self.eps = eps
        self.cache = None

    def forward(self, x: np.ndarray) -> np.ndarray:
        # x has shape (..., d). Normalize across the last axis.
        mu = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        std = np.sqrt(var + self.eps)
        x_hat = (x - mu) / std
        out = self.gamma * x_hat + self.beta
        self.cache = (x_hat, std)
        return out
Backward
The chain rule through layer norm has three terms because the mean and variance both depend on every input feature. The clean closed form is:
dx = (1 / (D · σ)) · [D · dx̂ − Σ dx̂ − x̂ · Σ (dx̂ · x̂)]
Derivation: write x̂ = (x − μ)/σ, then use
dx̂/dx = 1/σ · (I − 1/D · 11ᵀ − x̂ · x̂ᵀ / D).
Multiplying by the upstream gradient gives the three-term form
above. The two Σs are over the feature axis.
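One way to gain confidence in the closed form above is to compare it against central finite differences. The sketch below is a minimal, standalone check rather than part of the original solution: the helper names `layernorm_forward`, `layernorm_dx`, and `numerical_dx`, the step size `h`, and the random test shapes are all assumptions made for illustration.

```python
import numpy as np

def layernorm_forward(x, gamma, beta, eps=1e-5):
    # Normalize over the last axis, then apply the affine transform.
    mu = x.mean(axis=-1, keepdims=True)
    std = np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    x_hat = (x - mu) / std
    return gamma * x_hat + beta, x_hat, std

def layernorm_dx(dout, gamma, x_hat, std):
    # Three-term closed form: (D*dxhat - sum(dxhat) - x_hat*sum(dxhat*x_hat)) / (D*std)
    d = x_hat.shape[-1]
    dxhat = dout * gamma
    s1 = dxhat.sum(axis=-1, keepdims=True)
    s2 = (dxhat * x_hat).sum(axis=-1, keepdims=True)
    return (d * dxhat - s1 - x_hat * s2) / (d * std)

def numerical_dx(x, gamma, beta, dout, h=1e-6):
    # Central differences on the scalar loss L = sum(out * dout),
    # whose gradient w.r.t. out is exactly dout.
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        orig = x[idx]
        x[idx] = orig + h
        plus = (layernorm_forward(x, gamma, beta)[0] * dout).sum()
        x[idx] = orig - h
        minus = (layernorm_forward(x, gamma, beta)[0] * dout).sum()
        x[idx] = orig  # restore the perturbed entry
        grad[idx] = (plus - minus) / (2 * h)
    return grad

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5))
gamma = rng.normal(size=5)
beta = rng.normal(size=5)
dout = rng.normal(size=(3, 5))

_, x_hat, std = layernorm_forward(x, gamma, beta)
analytic = layernorm_dx(dout, gamma, x_hat, std)
numeric = numerical_dx(x, gamma, beta, dout)
max_err = np.abs(analytic - numeric).max()
print("max abs difference:", max_err)
```

Using a non-trivial `gamma` in the check matters: with `gamma` left at all ones, a missing `dout * gamma` step would go unnoticed.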
    def backward(self, dout: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return grads w.r.t. x, gamma, beta."""
        x_hat, std = self.cache
        d = x_hat.shape[-1]
        # Param grads (sum over all axes except the feature axis)
        sum_axes = tuple(range(dout.ndim - 1))
        dgamma = (dout * x_hat).sum(axis=sum_axes)
        dbeta = dout.sum(axis=sum_axes)
        # Backprop through normalization
        dxhat = dout * self.gamma
        # dx = (1 / (d * std)) * (d * dxhat - sum(dxhat) - x_hat * sum(dxhat * x_hat))
        s1 = dxhat.sum(axis=-1, keepdims=True)
        s2 = (dxhat * x_hat).sum(axis=-1, keepdims=True)
        dx = (1.0 / (d * std)) * (d * dxhat - s1 - x_hat * s2)
        return dx, dgamma, dbeta
Trace (forward only)
x = [2.0, 4.0, 6.0, 8.0] (single feature vector, d=4)
mu = 5.0
var = 5.0
std = sqrt(5 + 1e-5) ≈ 2.236
x_hat = [(2-5)/2.236, (4-5)/2.236, (6-5)/2.236, (8-5)/2.236]
≈ [-1.342, -0.447, 0.447, 1.342]
With gamma=[1,1,1,1], beta=[0,0,0,0]:
out ≈ x_hat ≈ [-1.342, -0.447, 0.447, 1.342]
mean(out) ≈ 0, var(out) ≈ 1 ✓
Layer norm vs batch norm
# Suppose x has shape (B, S, D): batch, sequence, features.
# Batch norm normalizes across the BATCH axis (and seq) per feature:
#   mu, var have shape (1, 1, D); statistics shared across all examples.
#   - Couples examples together. Different at train vs eval (running stats).
#   - Degrades for batch=1 and for variable-length seqs.
# Layer norm normalizes across the FEATURE axis per (example, position):
#   mu, var have shape (B, S, 1); independent per token.
#   - No coupling. Same behavior at train and eval.
#   - Trivially works for any batch size; standard in transformers.
This is the reason transformers standardize on layer norm (or a close variant such as RMSNorm) rather than batch norm: independence across examples, no train/eval skew, no batch-size sensitivity. Batch norm is still the default for CNNs on images, where batch-statistics regularization is part of the appeal.
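The coupling difference can be seen directly in a small experiment. The sketch below is illustrative only: `layer_norm_stats` and `batch_norm_stats` are hypothetical helper names, the batch-norm version uses training-mode batch statistics with no running averages or affine parameters, and the shapes are arbitrary. It perturbs one example in the batch and checks whether the normalized output of a different example changes.

```python
import numpy as np

def layer_norm_stats(x, eps=1e-5):
    # Statistics over the feature axis: one (mu, var) per token.
    mu = x.mean(axis=-1, keepdims=True)        # shape (B, S, 1)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps), mu.shape

def batch_norm_stats(x, eps=1e-5):
    # Statistics over batch and sequence axes: one (mu, var) per feature.
    mu = x.mean(axis=(0, 1), keepdims=True)    # shape (1, 1, D)
    var = x.var(axis=(0, 1), keepdims=True)
    return (x - mu) / np.sqrt(var + eps), mu.shape

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 8))                 # (B, S, D)
x_perturbed = x.copy()
x_perturbed[1] += 10.0                         # change example 1 only

ln_a, ln_shape = layer_norm_stats(x)
ln_b, _ = layer_norm_stats(x_perturbed)
bn_a, bn_shape = batch_norm_stats(x)
bn_b, _ = batch_norm_stats(x_perturbed)

# Compare the normalized output of example 0, which was not touched.
ln_unchanged = np.allclose(ln_a[0], ln_b[0])
bn_unchanged = np.allclose(bn_a[0], bn_b[0])
print("layer norm stats shape:", ln_shape, "example 0 unchanged:", ln_unchanged)
print("batch norm stats shape:", bn_shape, "example 0 unchanged:", bn_unchanged)
```

Under layer norm, example 0 is unaffected by what happens to example 1; under batch-statistics normalization, its output shifts because the shared mean and variance moved.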
Complexity
- Time: O(N · D) for both forward and backward (N = number of positions outside the feature axis).
- Space: O(N · D) for caching x̂ and σ.
Variations worth knowing
- RMSNorm: drop the mean subtraction and divide only by the root-mean-square of x. Slightly cheaper, and empirically competitive with layer norm in modern LLMs (used in LLaMA, T5).
- Group norm: a middle ground between layer and batch norm. Partition channels into groups and normalize within each group. Strong for small batches in vision.
- Pre-norm vs post-norm in transformers: applying LN before each sublayer (and once more after the final block) generally trains more stably than applying it after the residual add, which is why most modern decoder-only models are pre-norm.
- Weight standardization: the dual idea of normalizing weights instead of activations. Pairs well with group norm, particularly for small-batch training.
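As a concrete illustration of the RMSNorm variation listed above, here is a minimal forward-pass sketch in NumPy. The function name `rms_norm`, the `gain` parameter name, and the default `eps` are assumptions for this example, not taken from any particular library. It reuses the input vector from the forward trace so the two normalizations can be compared.

```python
import numpy as np

def rms_norm(x, gain, eps=1e-6):
    # Divide by the root-mean-square over the feature axis.
    # No mean subtraction and no bias term, unlike layer norm.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return gain * x / rms

x = np.array([[2.0, 4.0, 6.0, 8.0]])
gain = np.ones(4)
y = rms_norm(x, gain)

print("output:", y)
print("mean of squares:", np.mean(y ** 2))   # close to 1
print("mean:", y.mean())                     # not centered at 0
```

The output has unit mean-square but keeps a non-zero mean, which is the visible difference from layer norm on the same input.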