Cross-Entropy + Softmax Gradient
Interview Prep
The problem
Given pre-softmax logits z and a class index
y, compute (1) the softmax + cross-entropy loss in a
numerically stable way, and (2) its gradient with respect to
z. The famous result is
∂L/∂z = p − one_hot(y).
Why fuse softmax and cross-entropy?
Implementing softmax and cross-entropy separately is a numerical
disaster waiting to happen: you exponentiate, possibly overflow,
then divide, then take a log of something that may have underflowed.
All those operations are unnecessary if you reformulate the loss as
−log_softmax(z)[y] = −z_y + logsumexp(z). No
intermediate small probabilities, no log of tiny numbers.
import numpy as np

def softmax_naive(z):
    e = np.exp(z)  # overflows to inf for large z
    return e / e.sum()

def cross_entropy_naive(p, y_idx):
    return -np.log(p[y_idx])  # if p[y_idx] underflowed to 0, this is +inf

Naive: softmax then -log(p[y]).
- If a logit is huge, exp overflows -> p has inf/inf -> NaN.
- If p[y] underflows to 0, log(p[y]) is -inf and the loss is +inf.
Stable: combine into log-softmax = z - logsumexp(z), then loss = -log_p[y].
- logsumexp subtracts the max first; all exp arguments <= 0; safe.
- log_p is computed directly; never need a tiny p value.
- p is only formed for the GRADIENT, where we don't lose precision
  from taking a log.
Pattern: derive the gradient through the whole stack at once
Rather than backprop through softmax (whose Jacobian is an n × n
matrix) and then through cross-entropy, differentiate them jointly.
The Jacobian times the cross-entropy gradient collapses to a vector
difference, so the n × n matrix is never formed. This is one of the
most useful shortcuts in backprop, and it is why major frameworks
ship a fused primitive that takes raw logits, such as TensorFlow's
tf.nn.softmax_cross_entropy_with_logits or PyTorch's
nn.CrossEntropyLoss.
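The collapse can be checked numerically. The sketch below is illustrative only: the logits, the label, and the helper name softmax are made up for the example. It builds the full softmax Jacobian, applies the chain rule the long way, and compares the result with the fused vector difference.

```python
import numpy as np

def softmax(z):
    # Shift by the max so every exp argument is <= 0.
    e = np.exp(z - z.max())
    return e / e.sum()

z = np.array([2.0, -1.0, 0.5, 3.0])  # arbitrary example logits
y = 2                                # arbitrary true class
p = softmax(z)
n = z.size

# Long way: explicit Jacobian J[i, k] = dp_i/dz_k = p_i * ([i == k] - p_k).
jacobian = np.diag(p) - np.outer(p, p)

# Gradient of L = -log p_y with respect to p: only entry y is nonzero.
dL_dp = np.zeros(n)
dL_dp[y] = -1.0 / p[y]

# Chain rule: dL/dz_k = sum_i dL/dp_i * dp_i/dz_k.
grad_long = jacobian.T @ dL_dp

# Short way: the fused result p - one_hot(y).
grad_short = p.copy()
grad_short[y] -= 1.0

max_abs_diff = np.abs(grad_long - grad_short).max()
print(max_abs_diff)
```

The long way costs O(n²) time and memory for the Jacobian; the short way is O(n) and never divides by p_y, which matters when p_y is tiny.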
Derivation
Let z be pre-softmax logits, p_i = softmax(z)_i = exp(z_i)/Σ_j exp(z_j).
Let y be the true class index. Cross entropy: L = -log p_y.
Compute dL/dz_k for arbitrary k:
  L       = -z_y + log Σ_j exp(z_j)
  dL/dz_k = -[k == y] + exp(z_k) / Σ_j exp(z_j)
          = -[k == y] + p_k
          = p_k - [k == y]
i.e. dL/dz = p - one_hot(y).
That's it. All the complexity of softmax + log inside the loss cancels
to "predicted minus target" — the same form as linear regression's MSE
gradient.
Sanity check: as the model puts all mass on the true class
(p_y → 1, others → 0, which happens only in the limit of an
infinite logit margin), grad = p − one_hot(y) → 0. ✓
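A finite-difference check is a quick way to confirm the derivation before trusting it. The sketch below is one possible version; the function names loss_fn, analytic_grad, and numeric_grad, the step size h, and the random logits are all choices made for this example. It also probes the sanity check by growing the true-class margin.

```python
import numpy as np

def loss_fn(z, y):
    # L = -z_y + logsumexp(z), with the max subtracted for stability.
    m = z.max()
    return -z[y] + m + np.log(np.exp(z - m).sum())

def analytic_grad(z, y):
    # p - one_hot(y)
    p = np.exp(z - z.max())
    p /= p.sum()
    p[y] -= 1.0
    return p

def numeric_grad(z, y, h=1e-6):
    # Central differences, one coordinate at a time.
    g = np.zeros_like(z)
    for k in range(z.size):
        step = np.zeros_like(z)
        step[k] = h
        g[k] = (loss_fn(z + step, y) - loss_fn(z - step, y)) / (2 * h)
    return g

rng = np.random.default_rng(0)
z = rng.normal(size=5)
y = 3
fd_error = np.abs(analytic_grad(z, y) - numeric_grad(z, y)).max()

# Growing the margin of the true class drives the gradient toward zero.
grad_norms = [
    np.linalg.norm(analytic_grad(np.array([margin, 0.0, 0.0]), 0))
    for margin in (1.0, 5.0, 20.0)
]
print(fd_error, grad_norms)
```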
Fused implementation
def softmax_cross_entropy(z: np.ndarray, y_idx: int) -> tuple[float, np.ndarray]:
    """Fused forward + backward. z is the pre-softmax logit vector."""
    # 1) Stable softmax via log-sum-exp.
    m = z.max()
    log_sum_exp = m + np.log(np.exp(z - m).sum())
    log_p = z - log_sum_exp  # log probabilities, never explicitly exponentiated for the loss
    loss = -log_p[y_idx]  # cross entropy of one-hot label
    # 2) Gradient: dL/dz = p - one_hot(y).
    p = np.exp(log_p)
    grad = p.copy()
    grad[y_idx] -= 1.0
    return float(loss), grad

Complexity
- Time: O(C) for C classes, one pass for logsumexp and one for the gradient.
- Space: O(C).
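To see what the fused form buys, it helps to feed both versions deliberately extreme logits. The sketch below restates the naive and fused computations under the hypothetical names naive_loss and fused_loss_and_grad so it runs on its own; the logit vector is made up, and np.errstate is used only to silence NumPy's floating-point warnings.

```python
import numpy as np

def naive_loss(z, y_idx):
    # Softmax first, then -log: the unstable ordering.
    e = np.exp(z)
    p = e / e.sum()
    return -np.log(p[y_idx])

def fused_loss_and_grad(z, y_idx):
    # Same steps as the fused implementation, restated for a standalone example.
    m = z.max()
    log_p = z - (m + np.log(np.exp(z - m).sum()))
    grad = np.exp(log_p)
    grad[y_idx] -= 1.0
    return float(-log_p[y_idx]), grad

big = np.array([1000.0, 0.0, -1000.0])  # deliberately extreme logits

with np.errstate(all="ignore"):
    naive_overflow = naive_loss(big, 0)   # exp(1000) is inf, inf / inf is nan
    naive_underflow = naive_loss(big, 2)  # p[2] is 0, so -log(0) is inf

loss_easy, grad_easy = fused_loss_and_grad(big, 0)  # true class has the largest logit
loss_hard, grad_hard = fused_loss_and_grad(big, 2)  # true class has the smallest logit
print(naive_overflow, naive_underflow, loss_easy, loss_hard, grad_hard)
```

The fused version returns a finite loss in both cases, and the gradient for the badly misclassified example stays bounded because every entry of p − one_hot(y) lies in [−1, 1].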
Batched and label-smoothed variants
In practice z is (batch, C) and the label
is a vector of indices. With the loss averaged over the batch, the
gradient becomes (softmax(z) − Y) / batch, where Y is a
one-hot matrix and softmax is applied row-wise. With label
smoothing, replace the hard one-hot with (1 − ε)·one_hot + ε/C;
each smoothed row still sums to 1, so the gradient
is still p − Y but with the smoothed Y.
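A batched version with optional label smoothing might look like the sketch below. The function name batched_softmax_cross_entropy, the eps parameter, and the choice to average over the batch are assumptions made for this example, not a reference implementation.

```python
import numpy as np

def batched_softmax_cross_entropy(z, y_idx, eps=0.0):
    """z: (batch, C) logits; y_idx: (batch,) integer labels; eps: smoothing strength."""
    batch, num_classes = z.shape
    # Row-wise stable log-softmax.
    m = z.max(axis=1, keepdims=True)
    log_p = z - (m + np.log(np.exp(z - m).sum(axis=1, keepdims=True)))
    # Targets: (1 - eps) * one_hot + eps / C; eps = 0 gives plain one-hot rows.
    targets = np.full((batch, num_classes), eps / num_classes)
    targets[np.arange(batch), y_idx] += 1.0 - eps
    # Mean over the batch of -sum_c Y_c * log p_c.
    loss = -(targets * log_p).sum() / batch
    # Gradient of the mean loss: (p - Y) / batch.
    grad = (np.exp(log_p) - targets) / batch
    return float(loss), grad

rng = np.random.default_rng(0)
z = rng.normal(size=(4, 3))
y = np.array([0, 2, 1, 2])
loss_hard, grad_hard = batched_softmax_cross_entropy(z, y)
loss_smooth, grad_smooth = batched_softmax_cross_entropy(z, y, eps=0.1)
print(loss_hard, loss_smooth)
```

Each row of the gradient sums to zero in both cases, since p and the (smoothed) target row each sum to 1.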
Variations worth knowing
- Binary cross entropy (BCE): the 2-class case collapses to the
  logistic loss. Gradient w.r.t. the single logit is σ(z) − y.
  Same shape: predicted minus target.
- KL divergence loss: D_KL(Y || p) between two distributions.
  Gradient is still p − Y when paired with softmax. The entropy
  of Y only adds a constant.
- Negative log likelihood (NLL): PyTorch's name for the same loss,
  except that nn.NLLLoss expects log-probs as input, while
  nn.CrossEntropyLoss expects raw logits and does the fused stable
  form internally.
- Focal loss: reweight the per-example cross entropy by
  (1 − p_y)^γ to focus on hard examples. The gradient w.r.t. the
  logits is still a scalar multiple of p − one_hot(y), with the
  scalar depending on p_y and γ.
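The BCE case can be sketched the same way. The example below assumes a single scalar logit z and uses the hypothetical names bce_with_logits and two_class_softmax_ce; it checks that the stable logistic loss matches a 2-class softmax over the logits [0, z], and that the gradient is σ(z) − y.

```python
import numpy as np

def bce_with_logits(z, y):
    """z: scalar logit; y: target in [0, 1]. Returns (loss, dL/dz)."""
    # Stable form: softplus(z) - y*z = max(z, 0) - y*z + log(1 + exp(-|z|)).
    loss = max(z, 0.0) - y * z + np.log1p(np.exp(-abs(z)))
    # Stable sigmoid: only ever exponentiate a non-positive number.
    if z >= 0:
        sig = 1.0 / (1.0 + np.exp(-z))
    else:
        ez = np.exp(z)
        sig = ez / (1.0 + ez)
    return float(loss), float(sig - y)

def two_class_softmax_ce(z, y):
    # Softmax cross-entropy over logits [0, z]; class 1 is the positive class.
    logits = np.array([0.0, z])
    m = logits.max()
    log_p = logits - (m + np.log(np.exp(logits - m).sum()))
    grad = np.exp(log_p)
    grad[y] -= 1.0
    return float(-log_p[y]), float(grad[1])  # gradient w.r.t. z (index 1)

loss_bce, grad_bce = bce_with_logits(1.5, 1)
loss_sm, grad_sm = two_class_softmax_ce(1.5, 1)
print(loss_bce, loss_sm, grad_bce, grad_sm)
```

Fixing the first logit at 0 works because softmax is invariant to adding a constant to all logits, so only the difference between the two logits matters.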