Online Mean and Variance (Welford)

The problem

Compute the mean and variance of a stream of values in a single pass, without storing the data and without the numerical instability of the textbook formula. Each incoming value updates the running estimates in O(1) time and memory.

Why not the textbook formula?

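The "obvious" online formula keeps running sums Σx and Σx², then computes var = Σx²/n − (Σx/n)². This is correct algebraically. In floating point it can be a disaster: when the mean is large relative to the spread (e.g., temperatures around 1000 with variance 0.01), both terms are about 10^6 while their difference is 0.01, so roughly 8 of the ~16 significant digits of a double cancel away. As mean² grows toward 10^16 times the variance, every digit is lost and the result can even come out negative. This is catastrophic cancellation.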

def mean_var_sum_of_squares(xs):
    n = len(xs)
    s1 = sum(xs)                      # Σx
    s2 = sum(x * x for x in xs)       # Σx²
    mean = s1 / n
    var  = s2 / n - mean * mean       # numerically unstable: cancellation when mean² ≫ var
    return mean, var
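To see the cancellation concretely, here is a small sketch (the function names are ours, for illustration) that shifts the trace data used later on this page, [4, 7, 13, 16], by 10^9. The true population variance is unchanged at 22.5. Computing the mean first and then averaging squared deviations recovers it, while the sum-of-squares formula subtracts two numbers near 10^18, where adjacent doubles are 128 apart, far coarser than 22.5.

```python
def var_sum_of_squares(xs):
    # Textbook formula: E[x²] − E[x]², subtracting two huge, nearly equal numbers
    n = len(xs)
    mean = sum(xs) / n
    return sum(x * x for x in xs) / n - mean * mean

def var_two_pass(xs):
    # Compute the mean first, then average squared deviations from it
    n = len(xs)
    mean = sum(xs) / n
    return sum((x - mean) ** 2 for x in xs) / n

base = 1e9
xs = [base + d for d in (4.0, 7.0, 13.0, 16.0)]   # true population variance: 22.5

print(var_two_pass(xs))        # 22.5
print(var_sum_of_squares(xs))  # far from 22.5; the exact wrong value can vary by Python version
```

The shift does not change the spread at all, so any answer other than 22.5 is pure rounding error.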

Pattern: incremental updates, no cancellation

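Welford's recurrence updates the mean and a quantity M2 (the running sum of squared deviations from the current mean) directly, never forming large intermediate sums such as Σx². The trick is the asymmetric pair of deltas: delta uses the old mean, delta2 uses the new mean. Their product is exactly the amount M2 grows when x_n arrives: M2_n = M2_{n−1} + (x_n − mean_{n−1})(x_n − mean_n). Both factors are deviations on the scale of the data's spread, so no large, nearly equal quantities are ever subtracted.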
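The identity can be checked with exact rational arithmetic (Python's fractions.Fraction), so rounding plays no part: for every prefix of a sample sequence, the growth of the sum of squared deviations equals delta · delta2. The data values and the helper name sum_sq_dev are illustrative choices, not part of the solution.

```python
from fractions import Fraction

def sum_sq_dev(xs):
    # M2 computed from scratch: sum of squared deviations from the mean of xs
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs)

xs = [Fraction(v) for v in (4, 7, 13, 16, 2, 9)]
for n in range(1, len(xs)):
    x = xs[n]                                   # the (n+1)-th value arrives
    old_mean = sum(xs[:n]) / n                  # mean before the update
    new_mean = sum(xs[:n + 1]) / (n + 1)        # mean after the update
    growth = sum_sq_dev(xs[:n + 1]) - sum_sq_dev(xs[:n])
    assert growth == (x - old_mean) * (x - new_mean)   # delta * delta2, exactly
print("identity holds for every prefix")
```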

Solution

class RunningStats:
    """Welford's online algorithm for mean and variance."""
    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self.M2   = 0.0                # sum of squared deviations from current mean

    def update(self, x: float) -> None:
        self.n += 1
        delta  = x - self.mean
        self.mean += delta / self.n
        delta2 = x - self.mean         # NOTE: uses the NEW mean
        self.M2 += delta * delta2

    @property
    def variance(self) -> float:
        return self.M2 / self.n if self.n else 0.0   # population variance

    @property
    def sample_variance(self) -> float:
        return self.M2 / (self.n - 1) if self.n > 1 else 0.0
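A usage sketch: the same update rule, restated as a standalone function (welford is our name, not part of the solution above) so the snippet runs on its own. It consumes a generator, so no value is ever stored. On the trace data shifted by 10^9 it returns the exact population variance, because each delta is a small difference between nearby values.

```python
from collections.abc import Iterable

def welford(stream: Iterable[float]) -> tuple[int, float, float]:
    # Same recurrence as RunningStats.update, applied to any iterable
    n, mean, m2 = 0, 0.0, 0.0
    for x in stream:
        n += 1
        delta = x - mean          # deviation from the OLD mean
        mean += delta / n
        m2 += delta * (x - mean)  # times deviation from the NEW mean
    return n, mean, m2

# A generator: values are produced one at a time and never kept in memory
stream = (1e9 + d for d in (4.0, 7.0, 13.0, 16.0))
n, mean, m2 = welford(stream)
print(mean, m2 / n)   # 1000000010.0 22.5
```

Compare this with the sum-of-squares formula on the same shifted data, which cannot resolve a variance of 22.5 at all.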

Trace

xs = [4, 7, 13, 16]

n=1: x=4    delta=4-0=4       mean=0+4/1=4        delta2=4-4=0      M2=0+4*0=0
n=2: x=7    delta=7-4=3       mean=4+3/2=5.5      delta2=7-5.5=1.5  M2=0+3*1.5=4.5
n=3: x=13   delta=13-5.5=7.5  mean=5.5+7.5/3=8.0  delta2=13-8=5     M2=4.5+7.5*5=42
n=4: x=16   delta=16-8=8      mean=8+8/4=10       delta2=16-10=6    M2=42+8*6=90

mean = 10
variance = M2/n = 90/4 = 22.5

Sanity check: mean of [4, 7, 13, 16] is 10. Deviations are [−6, −3, 3, 6], squared and summed: 36 + 9 + 9 + 36 = 90. Population variance = 90/4 = 22.5. ✓
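The same check can be reproduced with Python's standard statistics module, which also distinguishes the two normalizations: pvariance divides M2 by n (the variance property above) and variance divides by n − 1 (sample_variance).

```python
import statistics

xs = [4, 7, 13, 16]
print(statistics.mean(xs))       # mean: 10
print(statistics.pvariance(xs))  # population variance: M2 / n = 90 / 4 = 22.5
print(statistics.variance(xs))   # sample variance: M2 / (n - 1) = 90 / 3 = 30
```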

Reference: two-pass (numerically stable, but must store the data)

def mean_var_naive(xs: list[float]) -> tuple[float, float]:
    n = len(xs)
    mean = sum(xs) / n                              # pass 1: the mean
    var  = sum((x - mean) ** 2 for x in xs) / n     # pass 2: population variance
    return mean, var

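Complexity

O(1) time and O(1) memory per update (the state is just n, mean, and M2), so O(n) time for a stream of n values. The two-pass reference is also O(n) time but needs O(n) memory to hold the data for the second pass.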

Variations worth knowing