Stable Softmax
Interview Prep
The problem
Implement $\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$ so that it returns the correct probabilities even when the inputs are large (say $x_i \approx 1000$). The naive, direct translation overflows. Fix it.
Pattern: shift invariance + log-sum-exp
Softmax is invariant under additive shifts: $\mathrm{softmax}(x - c) = \mathrm{softmax}(x)$ for any scalar $c$ subtracted from every component. (Each $e^{x_i - c}$ picks up the same factor $e^{-c}$, which cancels in the ratio.) So you are free to subtract any $c$ you like before exponentiating, and the standard trick is $c = \max_j x_j$: now the largest exponent is 0, the rest are negative, and nothing overflows.
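A quick numerical check of the shift-invariance claim, as a sketch: `softmax_unshifted` is a hypothetical helper that applies the formula directly, which is safe here only because the test logits are small. Every shift $c$ gives the same probabilities, including $c = \max_j x_j$.

```python
import numpy as np

def softmax_unshifted(x: np.ndarray) -> np.ndarray:
    # Direct formula; safe here only because the inputs below are small.
    e = np.exp(x)
    return e / e.sum()

x = np.array([0.5, -1.0, 2.0, 3.0])
p = softmax_unshifted(x)

for c in (-10.0, 0.0, 3.0, np.max(x)):
    # Subtracting the same scalar c from every logit leaves the probabilities unchanged.
    assert np.allclose(softmax_unshifted(x - c), p)

# With c = max(x), the largest exponent is exactly 0.
assert np.max(x - np.max(x)) == 0.0
```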
This is a canonical "numerical stability" question in ML engineering interviews. The interviewer is looking for: (i) you recognize the overflow problem, (ii) you know the max-subtraction trick, (iii) you can write the batched version that works along an arbitrary axis.
Naive: overflows for large inputs
```python
import numpy as np

def softmax_naive(x: np.ndarray) -> np.ndarray:
    """Mathematically correct, numerically dangerous."""
    e = np.exp(x)  # becomes inf once any x_i exceeds ~709.78 in float64
    return e / e.sum()
```
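For float64, np.exp(710.0) already overflows to inf, and the division then computes inf / inf = nan. In float32, a common dtype in ML, the threshold is only about 88.7. Logits in real models can reach hundreds in magnitude, so this code can return NaNs on an ordinary forward pass.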
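The overflow point follows from the floating-point range: exp(x) stays finite only while x is below the log of the largest representable value. A small sketch (`limit64`, `limit32`, and `probs` are illustrative names) that computes both thresholds with `np.finfo` and reproduces the inf / inf = nan step behind the NaNs:

```python
import numpy as np

# Largest x for which exp(x) is still finite: log of the largest representable value.
limit64 = np.log(np.finfo(np.float64).max)  # about 709.78
limit32 = np.log(np.finfo(np.float32).max)  # about 88.72

with np.errstate(over="ignore"):
    # Just past each limit, exp overflows to inf.
    assert np.isinf(np.exp(np.float64(710.0)))
    assert np.isinf(np.exp(np.float32(89.0)))
    # Just below each limit, the result is still finite.
    assert np.isfinite(np.exp(np.float64(709.0)))
    assert np.isfinite(np.exp(np.float32(88.0)))

with np.errstate(over="ignore", invalid="ignore"):
    e = np.exp(np.array([1000.0, 1001.0]))  # [inf, inf]
    probs = e / e.sum()                      # inf / inf -> nan in every entry
```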
Stable: subtract the max
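```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    """Stable softmax: subtract the max before exponentiating."""
    shifted = x - np.max(x)  # largest entry becomes exactly 0
    e = np.exp(shifted)      # every value now lies in [0, 1]
    return e / e.sum()
```
After the shift, the largest exponent is $0$ and the others lie in $(-\infty, 0]$, so every term $e^{x_i - \max_j x_j}$ is in $[0, 1]$. The denominator is therefore at least 1 (the max term contributes exactly $e^0 = 1$) and at most $n$. No overflow, and no division by zero.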
```python
>>> x = np.array([1000., 1001., 1002.])
>>> softmax_naive(x)
RuntimeWarning: overflow encountered in exp
array([nan, nan, nan])
>>> softmax(x)
array([0.09003057, 0.24472847, 0.66524096])
```
Batched: along an axis
Real ML code rarely operates on a single 1D vector: you're typically given a matrix of logits where each row is a separate distribution. axis=-1 says "softmax along the last axis", i.e. each row of a 2D matrix; keepdims=True keeps the reduced axis with size 1, so the per-row max and sum broadcast back against X.
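To see why keepdims=True matters, here is a small shape experiment (the arrays are made up for illustration). Without it, the per-row max has shape (rows,), which either fails to broadcast or, for a square matrix, silently broadcasts along the wrong axis:

```python
import numpy as np

X = np.array([[1.0, 2.0, 3.0],
              [10.0, 20.0, 30.0]])  # 2 rows = 2 separate distributions

m_keep = np.max(X, axis=-1, keepdims=True)  # shape (2, 1): one max per row, as a column
m_flat = np.max(X, axis=-1)                 # shape (2,): the reduced axis is dropped

# (2, 3) - (2, 1) broadcasts each row's max across that row: every row's max becomes 0.
assert np.allclose((X - m_keep).max(axis=-1), 0.0)

# (2, 3) - (2,) fails: the trailing dimensions 3 and 2 disagree.
try:
    X - m_flat
    raised = False
except ValueError:
    raised = True

# For a square matrix the same mistake broadcasts silently along the wrong axis.
S = np.array([[0.0, 5.0],
              [1.0, 2.0]])
wrong = S - np.max(S, axis=-1)                 # subtracts [5, 2] from every row
right = S - np.max(S, axis=-1, keepdims=True)  # subtracts each row's own max
```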
```python
import numpy as np

def softmax_batch(X: np.ndarray, axis: int = -1) -> np.ndarray:
    """Stable softmax along an arbitrary axis (rows of a batched logit matrix)."""
    shifted = X - np.max(X, axis=axis, keepdims=True)  # per-slice max, kept broadcastable
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)
```
Bonus: log-softmax
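For cross-entropy loss you really want log(softmax(x)), not softmax(x) followed by a separate log: if some probability underflows to 0, the separate log returns -inf. Combining the two via the log-sum-exp identity, $\log \mathrm{softmax}(x)_i = x_i - m - \log \sum_j e^{x_j - m}$ with $m = \max_j x_j$, skips the divide-then-log round trip and keeps every entry finite: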
```python
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    """log-softmax via log-sum-exp. Even more numerically friendly for
    downstream cross-entropy."""
    m = np.max(x)
    return x - m - np.log(np.exp(x - m).sum())  # x_i - logsumexp(x)
```
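A small comparison, as a sketch, between taking np.log of the stable softmax and calling log_softmax directly, on a logit vector with one very negative entry. Both functions are restated so the snippet runs on its own:

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Stable softmax, as defined earlier.
    e = np.exp(x - np.max(x))
    return e / e.sum()

def log_softmax(x: np.ndarray) -> np.ndarray:
    # log-softmax via log-sum-exp, as defined above.
    m = np.max(x)
    return x - m - np.log(np.exp(x - m).sum())

x = np.array([0.0, -1000.0])

with np.errstate(divide="ignore"):
    # exp(-1000) underflows to 0.0, so the separate log produces -inf for that entry.
    two_step = np.log(softmax(x))

# The combined form never exponentiates and then logs the small entry,
# so it returns the exact finite value -1000 for it.
one_step = log_softmax(x)
```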
PyTorch's F.cross_entropy(logits, target) uses this combined form internally (it is equivalent to log_softmax followed by nll_loss). That's why it's preferred over F.nll_loss(softmax(logits).log(), target), which is mathematically equivalent but loses precision and can hit log(0) = -inf when a probability underflows.
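The same combination can be sketched in NumPy for a batch. `cross_entropy` and `log_softmax_batch` are hypothetical helpers written for illustration, not PyTorch's actual code; they take integer class targets and return the mean negative log-likelihood:

```python
import numpy as np

def log_softmax_batch(X: np.ndarray, axis: int = -1) -> np.ndarray:
    # Batched log-sum-exp: subtract each slice's max, then its log-normalizer.
    m = np.max(X, axis=axis, keepdims=True)
    return X - m - np.log(np.exp(X - m).sum(axis=axis, keepdims=True))

def cross_entropy(logits: np.ndarray, target: np.ndarray) -> float:
    """Mean negative log-likelihood of integer class targets, from raw logits.

    Hypothetical helper: a NumPy sketch of log_softmax + NLL, not PyTorch's code.
    """
    logp = log_softmax_batch(logits, axis=-1)        # shape (batch, classes)
    picked = logp[np.arange(len(target)), target]   # log-prob of each row's target
    return float(-picked.mean())

logits = np.array([[2.0, 0.5, -1.0],
                   [0.0, 0.0, 1000.0]])
target = np.array([0, 1])
loss = cross_entropy(logits, target)
```

The second row puts almost all of its mass on class 2 while its target is class 1; that row's loss stays finite (about 1000) instead of becoming inf.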
Complexity
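- Time: $O(n)$ for a length-$n$ input.
- Space: $O(n)$ for the output.
- Passes: the version above reads the input twice to build the normalizer (once for the max, once for the sum of exponentials) and then divides elementwise. The max does not actually have to be known before summing: an "online" variant keeps a running max and rescales the partial sum whenever the max increases, so the max and the sum come out of a single pass. Fused attention kernels such as FlashAttention rely on this.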
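The single-pass normalizer can be sketched in plain Python. `online_logsumexp` and `softmax_online` are hypothetical names; whenever a new element raises the running max, the partial sum is rescaled by exp(old_max - new_max) so it stays expressed relative to the current max. The sketch assumes a non-empty input:

```python
import math

def online_logsumexp(xs):
    """One pass over xs, tracking running max m and running sum s of exp(x - m)."""
    m = -math.inf
    s = 0.0
    for x in xs:
        if x > m:
            # Re-base the running sum onto the new, larger max (exp(-inf) == 0.0 on the first step).
            s = s * math.exp(m - x) + 1.0
            m = x
        else:
            s += math.exp(x - m)
    return m + math.log(s)

def softmax_online(xs):
    # Pass 1 (fused max + sum) gives the log-normalizer; pass 2 writes the outputs.
    lse = online_logsumexp(xs)
    return [math.exp(x - lse) for x in xs]

probs = softmax_online([1000.0, 1001.0, 1002.0])
```

The result does not depend on the order in which elements arrive, which is what lets a kernel process the input in blocks.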
Variations worth knowing
- Temperature scaling: $\mathrm{softmax}(x / T)$. $T < 1$ sharpens the distribution toward the argmax; $T > 1$ flattens it toward uniform. The stability trick still applies: subtract the max after dividing by $T$.
- Gumbel-softmax: draw a differentiable, relaxed sample from a categorical distribution by adding i.i.d. Gumbel(0, 1) noise $g$ to the logits and taking a soft argmax, $\mathrm{softmax}((x + g) / \tau)$.
- Sparsemax: a piecewise-linear alternative to softmax that produces exactly-zero entries. Useful for interpretability in attention; computable in $O(n \log n)$ via a sort.
- Hierarchical softmax: for very large output vocabularies (millions of words), a tree-based softmax reduces the cost from $O(V)$ to $O(\log V)$ per token.
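Minimal NumPy sketches of three of these variations, under stated assumptions: `softmax_temperature`, `gumbel_softmax_sample`, and `sparsemax` are hypothetical helper names; `T` and `tau` must be positive; the Gumbel noise comes from NumPy's `Generator.gumbel`; and sparsemax uses the standard sort-based projection onto the probability simplex, which is where the $O(n \log n)$ comes from:

```python
import numpy as np

def softmax_temperature(x: np.ndarray, T: float = 1.0) -> np.ndarray:
    """Stable softmax of x / T (T > 0): divide first, then subtract the max."""
    z = x / T
    e = np.exp(z - np.max(z))
    return e / e.sum()

def gumbel_softmax_sample(x: np.ndarray, tau: float, rng: np.random.Generator) -> np.ndarray:
    """Relaxed categorical sample: softmax((x + g) / tau), g ~ Gumbel(0, 1) i.i.d."""
    g = rng.gumbel(loc=0.0, scale=1.0, size=x.shape)
    return softmax_temperature(x + g, T=tau)

def sparsemax(x: np.ndarray) -> np.ndarray:
    """Euclidean projection of x onto the probability simplex (sort-based)."""
    z = np.sort(x)[::-1]                 # sort descending: the O(n log n) step
    k = np.arange(1, len(z) + 1)
    cssv = np.cumsum(z)                  # cumulative sums of the sorted values
    support = 1 + k * z > cssv           # True for the top-k entries that stay nonzero
    k_max = k[support][-1]               # size of the support
    tau = (cssv[k_max - 1] - 1) / k_max  # threshold subtracted from every entry
    return np.maximum(x - tau, 0.0)
```

For example, sparsemax(np.array([1.0, 0.0, -1.0])) returns exactly [1.0, 0.0, 0.0], while softmax on the same input keeps every entry strictly positive.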