Stable Softmax

The problem

Implement $\mathrm{softmax}(x)_i = e^{x_i} / \sum_j e^{x_j}$ so that it returns the correct probabilities even when the inputs are large (say $x_i \approx 1000$). The naive direct translation overflows. Fix it.

Pattern: shift invariance + log-sum-exp

Softmax is invariant under additive shifts: $\mathrm{softmax}(x - c) = \mathrm{softmax}(x)$ for any scalar $c$. (Each term $e^{x_i - c} = e^{x_i} e^{-c}$ picks up the same factor $e^{-c}$, which cancels in the ratio.) So you are free to subtract any $c$ you like before exponentiating, and the standard trick is $c = \max_j x_j$: now the largest exponent is 0, the rest are non-positive, and nothing overflows.
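A quick numerical check of the shift-invariance identity, using small inputs where the direct formula is still safe. The helper name naive_softmax and the particular shift values are illustrative choices, not part of the original:

```python
import numpy as np

def naive_softmax(x: np.ndarray) -> np.ndarray:
    # Direct formula; safe here only because the inputs below are small.
    e = np.exp(x)
    return e / e.sum()

x = np.array([1.0, 2.0, 3.0])
p = naive_softmax(x)

# Subtracting any scalar c multiplies every e^{x_i} by e^{-c},
# which cancels between numerator and denominator.
for c in (-5.0, 0.0, 2.5, x.max()):
    assert np.allclose(naive_softmax(x - c), p)

print(p)  # approximately [0.09003057 0.24472847 0.66524096]
```

Choosing c = x.max() is just one member of this family; it is the one that guarantees every exponent is at most 0.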

This is the canonical "numerical stability" question asked of every ML engineering candidate. The interviewer is looking for: (i) you recognize the overflow problem, (ii) you know the max-subtraction trick, (iii) you can write the batched version that works along an arbitrary axis.

Naive: overflows for large inputs

import numpy as np

def softmax_naive(x: np.ndarray) -> np.ndarray:
    """Mathematically correct, numerically dangerous."""
    e = np.exp(x)
    return e / e.sum()

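For float64, $e^{x}$ overflows to inf once $x$ exceeds about 709.78 ($\ln$ of the largest double, roughly $1.8 \times 10^{308}$); in float32 the limit is only about 88.72. Transformer logits can reach hundreds in magnitude, and once any $e^{x_i}$ is inf the division becomes inf/inf = NaN, so this code can return NaNs on an ordinary forward pass.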
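The overflow thresholds can be read off np.finfo rather than memorized. A minimal sketch; np.errstate is used only to silence the expected overflow and invalid-value warnings:

```python
import numpy as np

# exp(x) overflows once x exceeds ln(largest finite value of the dtype).
thresholds = {}
for dtype in (np.float64, np.float32):
    thresholds[dtype.__name__] = float(np.log(np.finfo(dtype).max))
print(thresholds)  # float64 about 709.78, float32 about 88.72

with np.errstate(over="ignore", invalid="ignore"):
    # Just past the float32 limit, exp already returns inf ...
    e32 = np.exp(np.array([89.0], dtype=np.float32))
    # ... and the naive softmax then divides inf by inf, giving NaN.
    e = np.exp(np.array([1000.0, 1001.0]))
    p = e / e.sum()

print(e32, p)
```

The float32 number matters in practice because it is the usual training dtype, so the naive version fails at much smaller logits than the float64 example suggests.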

Stable: subtract the max

import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    """Stable softmax: subtract the max before exponentiating."""
    shifted = x - np.max(x)
    e = np.exp(shifted)
    return e / e.sum()

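After the shift, the largest term is $e^0 = 1$ and every other term $e^{x_i - \max_j x_j}$ lies in $[0, 1]$. The denominator is therefore at least 1 and at most $n$, the length of $x$. Nothing can overflow; the only loss is that terms far below the max may underflow toward 0, which matches their negligible true probabilities.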
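The denominator bound is easy to check empirically. A small sketch, assuming normally distributed logits with a large scale (1e4, an arbitrary choice) so that the naive version would certainly overflow:

```python
import numpy as np

rng = np.random.default_rng(0)
for n in (1, 10, 1000):
    # Logits far beyond the float64 overflow threshold of exp.
    x = rng.normal(scale=1e4, size=n)
    shifted = x - np.max(x)
    terms = np.exp(shifted)  # the max entry gives exactly 1.0; others are in [0, 1]
    denom = terms.sum()
    assert 1.0 <= denom <= n
    p = terms / denom
    assert np.isfinite(p).all() and np.isclose(p.sum(), 1.0)
```

Most of the non-max terms underflow to 0 here; NumPy ignores underflow by default, so no warnings are printed.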

>>> x = np.array([1000., 1001., 1002.])
>>> softmax_naive(x)
RuntimeWarning: overflow encountered in exp
array([nan, nan, nan])
>>> softmax(x)
array([0.09003057, 0.24472847, 0.66524096])

Batched: along an axis

Real ML code rarely operates on a single 1D vector: you're typically given a matrix of logits where each row is a separate distribution. axis=-1 says "softmax each row" (normalize along the last axis); keepdims=True keeps the reduced axis with size 1, so the max and the sum broadcast back against X correctly.

def softmax_batch(X: np.ndarray, axis: int = -1) -> np.ndarray:
    """Stable softmax along an arbitrary axis (rows of a batched logit matrix)."""
    shifted = X - np.max(X, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)
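A short usage sketch of softmax_batch on a 3x3 matrix of logits, plus a demonstration of why keepdims=True matters. The matrix values are arbitrary:

```python
import numpy as np

def softmax_batch(X: np.ndarray, axis: int = -1) -> np.ndarray:
    # Same as above: keepdims=True leaves the reduced axis with size 1.
    shifted = X - np.max(X, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

X = np.array([[1000.0, 1001.0, 1002.0],
              [0.0, 0.0, 0.0],
              [-5.0, 0.0, 5.0]])

P_rows = softmax_batch(X)          # each row sums to 1
P_cols = softmax_batch(X, axis=0)  # each column sums to 1
assert np.allclose(P_rows.sum(axis=1), 1.0)
assert np.allclose(P_cols.sum(axis=0), 1.0)

# Without keepdims, X.max(axis=-1) has shape (3,) and broadcasts along the
# LAST axis, so for a square X each column silently gets another row's max.
wrong = X - X.max(axis=-1)
right = X - X.max(axis=-1, keepdims=True)
print(np.array_equal(wrong, right))  # False
```

For a non-square matrix the missing keepdims usually raises a broadcasting error; the square case is the dangerous one because it runs without complaint.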

Bonus: log-softmax

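For cross-entropy loss you want log(softmax(x)), and computing softmax(x) first and then taking the log is fragile: a probability that underflows to 0 becomes $\log 0 = -\infty$. The log-sum-exp identity $\log \mathrm{softmax}(x)_i = x_i - m - \log \sum_j e^{x_j - m}$, with $m = \max_j x_j$, computes the result directly, skipping the divide-then-log round trip: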

def log_softmax(x: np.ndarray) -> np.ndarray:
    """log-softmax via log-sum-exp. Even more numerically friendly for
    downstream cross-entropy."""
    m = np.max(x)
    return x - m - np.log(np.exp(x - m).sum())
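A minimal comparison of the two routes on an input where one probability underflows. Both functions are copied from the blocks above; the input [0, -1000] is an arbitrary choice:

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    shifted = x - np.max(x)
    e = np.exp(shifted)
    return e / e.sum()

def log_softmax(x: np.ndarray) -> np.ndarray:
    m = np.max(x)
    return x - m - np.log(np.exp(x - m).sum())

x = np.array([0.0, -1000.0])

with np.errstate(divide="ignore"):
    # e^{-1000} underflows to 0.0 in float64, so log turns it into -inf.
    via_log = np.log(softmax(x))

# Log-sum-exp keeps the second entry as the finite value -1000.
direct = log_softmax(x)

print(via_log)  # values: 0.0 and -inf
print(direct)   # values: 0.0 and -1000.0
```

A -inf in a loss term propagates to an infinite loss and NaN gradients, which is why the direct form matters in training.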

PyTorch's F.cross_entropy(logits, target) takes raw logits and is documented as equivalent to F.log_softmax followed by F.nll_loss, i.e. this combined form. That's why it's preferred over F.nll_loss(F.softmax(logits, dim=1).log(), target), which is mathematically equivalent but loses precision and can hit log(0) = -inf.
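The same combined form can be written in NumPy. This is a sketch of what cross-entropy on logits computes, not PyTorch's actual implementation; log_softmax_batch and cross_entropy are hypothetical helper names, and targets are assumed to be integer class indices:

```python
import numpy as np

def log_softmax_batch(X: np.ndarray, axis: int = -1) -> np.ndarray:
    # Row-wise log-sum-exp with the max shift, as in log_softmax above.
    m = np.max(X, axis=axis, keepdims=True)
    return X - m - np.log(np.exp(X - m).sum(axis=axis, keepdims=True))

def cross_entropy(logits: np.ndarray, target: np.ndarray) -> float:
    """Mean negative log-likelihood of integer class targets, from raw logits."""
    logp = log_softmax_batch(logits, axis=-1)
    # Pick log p[target] in each row, then average over the batch.
    picked = logp[np.arange(logits.shape[0]), target]
    return float(-picked.mean())

logits = np.array([[2.0, 1.0, 0.1],
                   [0.0, 1000.0, -1000.0]])
target = np.array([0, 2])

loss = cross_entropy(logits, target)  # finite, about 1000.21

# For comparison: probabilities first, then log. The max shift prevents
# overflow, but exp(-2000) underflows to 0.0 and log(0) = -inf.
shifted = logits - logits.max(axis=-1, keepdims=True)
P = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
with np.errstate(divide="ignore"):
    naive_loss = float(-np.log(P[np.arange(logits.shape[0]), target]).mean())

print(loss, naive_loss)  # naive_loss is inf
```

The second row deliberately assigns the target class a very low logit: the combined form reports a large but finite loss, while the probabilities-then-log route reports inf.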

Complexity

Variations worth knowing