K-Means From Scratch
Interview Prep
The problem
Implement k-means clustering in pure NumPy. Given n
points in d dimensions and a number of clusters
k, return the cluster centroids and the per-point
labels. The objective is to minimize the sum of squared distances
from each point to its assigned centroid (the "inertia").
Lloyd's algorithm in one sentence
Alternate between (1) assigning each point to the nearest centroid and (2) recomputing each centroid as the mean of its assigned points. Repeat until the centroids stop moving. Neither step can increase the inertia, and there are only finitely many ways to partition n points into k clusters, so the algorithm always terminates, though not necessarily at the global optimum.
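Written as equations, the two steps are coordinate descent on the inertia. The notation below (x_i for points, \mu_j for centroids, c_i for labels, S_j for the set of points in cluster j) is introduced here for illustration and does not appear elsewhere in this write-up:

```latex
J(c, \mu) = \sum_{i=1}^{n} \lVert x_i - \mu_{c_i} \rVert^2

\text{Assign:}\quad c_i \leftarrow \arg\min_{j \in \{1,\dots,k\}} \lVert x_i - \mu_j \rVert^2

\text{Update:}\quad \mu_j \leftarrow \frac{1}{|S_j|} \sum_{i \in S_j} x_i,
\qquad S_j = \{\, i : c_i = j \,\}
```

The assign step is the exact minimizer of J over the labels with the centroids held fixed, and the mean is the exact minimizer over each centroid with the labels held fixed.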
Solution
import numpy as np

def kmeans(X: np.ndarray, k: int, max_iter: int = 100, tol: float = 1e-6, seed: int = 0):
    """Lloyd's algorithm. X is (n, d). Returns (centroids, labels)."""
    rng = np.random.default_rng(seed)
    n, d = X.shape

    # 1) Initialize: random sample without replacement
    idx = rng.choice(n, size=k, replace=False)
    centroids = X[idx].copy()

    for it in range(max_iter):
        # 2) Assign: each point to nearest centroid
        #    ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2
        #    The ||x||^2 term doesn't affect argmin, drop it.
        d2 = -2 * X @ centroids.T + (centroids ** 2).sum(axis=1)  # (n, k)
        labels = d2.argmin(axis=1)

        # 3) Update: each centroid = mean of its assigned points
        new_centroids = np.empty_like(centroids)
        for j in range(k):
            mask = labels == j
            if mask.any():
                new_centroids[j] = X[mask].mean(axis=0)
            else:
                # Re-seed empty cluster from a random point
                new_centroids[j] = X[rng.integers(n)]

        # 4) Check convergence
        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift < tol:
            break

    return centroids, labels

Anatomy of the assignment step
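The assignment step is where almost all of the time goes. For
n points and k centroids, the distance
matrix is (n, k). The expansion
||x − c||² = ||x||² − 2x·c + ||c||² lets us compute it
as one matrix multiply X @ Cᵀ plus two broadcasts: a per-row
||x||² term and a per-column ||c||² term. The ||x||² term is
constant within each row, so it doesn't influence the
argmin and we drop it. The resulting scores are no longer true
squared distances and can be negative.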
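As a quick sanity check of that shortcut, the sketch below compares it with a brute-force distance computation. The data is synthetic, and the names `full`, `scores`, and `recovered` are made up for this example:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))   # 200 synthetic points in 3 dimensions
C = rng.normal(size=(4, 3))     # 4 synthetic centroids

# Brute force: full (n, k, d) difference tensor, then squared norms.
full = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2)   # (n, k)

# Shortcut: drop ||x||^2, which is constant along each row.
scores = -2 * X @ C.T + (C ** 2).sum(axis=1)                # (n, k)

labels_full = full.argmin(axis=1)
labels_fast = scores.argmin(axis=1)

# Adding ||x||^2 back recovers the true squared distances.
recovered = scores + (X ** 2).sum(axis=1, keepdims=True)
```

Both label arrays should agree, and `recovered` should match `full` up to floating-point error. The brute-force version materializes an (n, k, d) tensor, which is the main reason to prefer the matrix-multiply form.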
Empty clusters are a real failure mode in vanilla k-means. If a centroid is far from everything, no points get assigned to it and its update is ill-defined. Standard fix: re-seed the empty cluster from a random data point.
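A small illustration of how an empty cluster shows up, using made-up coordinates with one centroid placed far from the data. Counting labels with `np.bincount` and `minlength=k` is one simple way to detect it:

```python
import numpy as np

# Three points near the origin and one centroid placed far away (made-up values).
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1]])
centroids = np.array([[0.0, 0.0], [100.0, 100.0]])

d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
labels = d2.argmin(axis=1)

# minlength=k keeps a zero entry for clusters that received no points.
counts = np.bincount(labels, minlength=len(centroids))
empty = np.flatnonzero(counts == 0)   # indices of clusters with no members
```

Here every point is assigned to centroid 0, so `empty` contains index 1. Taking `X[labels == 1].mean(axis=0)` for that cluster would produce NaNs, which is why the update needs an explicit branch for this case.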
k-means++ initialization
Uniform random initialization often places several seeds in the same
natural cluster, and Lloyd's algorithm then tends to get stuck in a
poor local minimum. The k-means++ scheme picks the first centroid
uniformly at random from the data, then samples each subsequent
centroid from the data points with probability proportional to each
point's squared distance from its nearest existing centroid. This
spreads the initial seeds out and guarantees that the expected
inertia is within an O(log k) factor of the optimal
inertia.
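To make the sampling rule concrete, here is a toy calculation of the weights for the second centroid. The three 1-D points are made up, and the first centroid is assumed to have landed on x = 0:

```python
import numpy as np

# Toy 1-D data set; the first centroid is assumed to be the point at x = 0.
X = np.array([[0.0], [1.0], [10.0]])
first = X[0]

d2 = ((X - first) ** 2).sum(axis=1)   # squared distance to the only centroid
probs = d2 / d2.sum()                 # sampling weights for the next centroid
```

The squared distances are 0, 1, and 100, so the weights are 0, 1/101, and 100/101. The far point is chosen about 99% of the time, and the point already used as a centroid has zero weight.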
def kmeans_pp_init(X: np.ndarray, k: int, rng):
    """k-means++ seeding: spread initial centroids."""
    n = X.shape[0]
    centroids = [X[rng.integers(n)]]
    for _ in range(1, k):
        # Distance from each point to the nearest existing centroid
        d2 = np.min(
            np.array([((X - c) ** 2).sum(axis=1) for c in centroids]),
            axis=0,
        )
        probs = d2 / d2.sum()
        i = rng.choice(n, p=probs)
        centroids.append(X[i])
    return np.array(centroids)

The objective

def inertia(X, centroids, labels):
    """Sum of squared distances to each point's assigned centroid (the k-means objective)."""
    return float(((X - centroids[labels]) ** 2).sum())
Plotting inertia vs k gives the "elbow" used to pick
k in practice. The bend in the curve is the
diminishing-returns point.
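A minimal sketch of how such a curve might be produced. Everything here is an assumption for illustration: the synthetic three-blob data, the compact `lloyd` helper (a stripped-down stand-in, not the solution above), and the choice of five restarts per k:

```python
import numpy as np

def lloyd(X, k, rng, n_iter=50):
    """Compact Lloyd's algorithm used only for this demo."""
    centroids = X[rng.choice(len(X), size=k, replace=False)].copy()
    for _ in range(n_iter):
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        for j in range(k):
            if (labels == j).any():   # leave empty clusters where they are
                centroids[j] = X[labels == j].mean(axis=0)
    # Final assignment so the labels match the returned centroids.
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return centroids, d2.argmin(axis=1)

rng = np.random.default_rng(0)
# Three synthetic blobs in 2-D (made-up data).
centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
X = np.concatenate([c + rng.normal(size=(100, 2)) for c in centers])

inertias = []
for k in range(1, 7):
    # Keep the best of a few restarts to reduce the effect of bad seeds.
    best = np.inf
    for _ in range(5):
        centroids, labels = lloyd(X, k, rng)
        best = min(best, float(((X - centroids[labels]) ** 2).sum()))
    inertias.append(best)

for k, value in zip(range(1, 7), inertias):
    print(k, round(value, 1))
```

On data like this the printed values should fall steeply up to the true number of blobs and flatten afterwards, although a run of unlucky initializations can blur the bend.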
Complexity
- Time: O(n · k · d) per iteration. Lloyd's algorithm can need exponentially many iterations in the worst case, but in practice it converges in a handful.
- Space: O(n · k) for the distance matrix.
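If the (n, k) matrix is too large to hold at once, one possible workaround is to run the assignment step in row chunks, which caps the temporary at (chunk_size, k). The helper name `assign_chunked` and its `chunk_size` parameter are hypothetical, introduced only for this sketch:

```python
import numpy as np

def assign_chunked(X, centroids, chunk_size=1024):
    """Nearest-centroid labels, computed chunk_size rows at a time."""
    c_norm = (centroids ** 2).sum(axis=1)            # (k,), reused by every chunk
    labels = np.empty(len(X), dtype=np.intp)
    for start in range(0, len(X), chunk_size):
        block = X[start:start + chunk_size]
        scores = -2 * block @ centroids.T + c_norm   # (chunk, k) instead of (n, k)
        labels[start:start + chunk_size] = scores.argmin(axis=1)
    return labels

rng = np.random.default_rng(0)
X = rng.normal(size=(5000, 8))    # synthetic data
C = rng.normal(size=(16, 8))      # synthetic centroids

labels_chunked = assign_chunked(X, C, chunk_size=512)
labels_full = (-2 * X @ C.T + (C ** 2).sum(axis=1)).argmin(axis=1)
```

The chunked labels should match the one-shot version. The time cost stays O(n · k · d); only the peak memory for the score matrix changes.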
Variations worth knowing
- Mini-batch k-means: on each step, update the centroids from a small random subset of points, using a learning rate that shrinks as each centroid absorbs more points. Often much faster on large datasets, usually at a small cost in final inertia.
- k-medoids (PAM): centroids are constrained to be data points. More robust to outliers; uses O(n²) pairwise distances.
- Soft k-means / Gaussian mixture: replace hard assignments with posterior probabilities. The natural EM generalization; recovers k-means as the zero-variance limit.
- Spectral clustering: when clusters are non-convex, run k-means on the eigenvectors of a graph Laplacian instead of on the raw points. Bends k-means to non-linear cluster shapes.
- Choosing k: elbow on inertia, silhouette score, gap statistic, or domain knowledge. There's no universal answer.
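As one illustration from the list above, here is a rough sketch of the mini-batch variant. The per-centroid 1/count step size is a common choice assumed here, not something specified in this write-up, and `minibatch_kmeans`, `batch_size`, and `n_steps` are hypothetical names:

```python
import numpy as np

def minibatch_kmeans(X, k, batch_size=64, n_steps=200, seed=0):
    """Mini-batch k-means sketch with a per-centroid 1/count learning rate."""
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    counts = np.zeros(k)                       # points absorbed so far, per centroid
    for _ in range(n_steps):
        batch = X[rng.choice(len(X), size=batch_size, replace=False)]
        scores = -2 * batch @ centroids.T + (centroids ** 2).sum(axis=1)
        labels = scores.argmin(axis=1)
        for x, j in zip(batch, labels):
            counts[j] += 1
            lr = 1.0 / counts[j]               # step size shrinks as counts grow
            centroids[j] = (1 - lr) * centroids[j] + lr * x
    return centroids

rng = np.random.default_rng(1)
# Three synthetic blobs in 2-D (made-up data).
centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
X = np.concatenate([c + rng.normal(size=(500, 2)) for c in centers])
centroids = minibatch_kmeans(X, k=3)
```

With the 1/count step size, each centroid is the running mean of every point it has absorbed, so early noisy batches are gradually averaged out. Each step touches only `batch_size` points, so the per-step cost no longer depends on n.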