Median of Two Sorted Arrays


Legendary · Binary Search · binary-search · invariants

The problem

Given two sorted arrays a (length n) and b (length m), at least one of them non-empty, return the median of their combined sorted contents without actually merging them. The required time is O(log min(n, m)). This is the canonical "I can't believe Google asks this" problem.

Pattern: binary search on the partition

Don't search for the median value. Search for the right partition: a place to split each array so that the left halves of both arrays together contain exactly the smaller half of all elements, and the right halves contain the larger half. Once that's right, the median is read off from the four boundary values.
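To make "the right partition" concrete, here is a small brute-force checker. It is a sketch, and the name `is_median_partition` is hypothetical: it builds both left halves explicitly and tests the definition directly (the left side holds exactly (n + m + 1) // 2 elements, and nothing on the left exceeds anything on the right). That costs O(n + m) per check, but it is easy to trust.

```python
def is_median_partition(a: list[int], b: list[int], i: int, j: int) -> bool:
    """Hypothetical helper: check by definition whether splitting a at i and b at j
    is a median partition. Materializes both halves, so it is O(n + m)."""
    left = a[:i] + b[:j]
    right = a[i:] + b[j:]
    half = (len(a) + len(b) + 1) // 2
    if len(left) != half:
        return False
    # Every left-side element must be <= every right-side element.
    if left and right and max(left) > min(right):
        return False
    return True


a, b = [1, 3, 8], [2, 4, 5, 9]
print(is_median_partition(a, b, 1, 3))  # False: left [1, 2, 4, 5] has 5 > 3 on the right
print(is_median_partition(a, b, 2, 2))  # True: max(left) = 4 <= min(right) = 5
```

With the valid split (2, 2) the median of the odd-length combined array [1, 2, 3, 4, 5, 8, 9] is max(left) = 4.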

Binary search runs over the split positions of the smaller array; that's where the log min(n, m) bound comes from. For each candidate split i in a, the matching split j in b is forced: j = half - i, where half = (n + m + 1) // 2, so the left side always holds exactly (n + m + 1) // 2 elements. Check the partition, adjust the range, repeat.
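Before adding the binary search, it can help to see the search space on its own. The sketch below (hypothetical name `find_partition_linear`) tries every split i of the shorter array, forces j = half - i, and returns the first split that passes the four-boundary-value check. It runs in O(min(n, m)) rather than logarithmic time; the binary search only replaces this loop.

```python
def find_partition_linear(a: list[int], b: list[int]) -> tuple[int, int]:
    """Hypothetical helper: return (i, j) forming a valid median partition by trying every i."""
    if len(a) > len(b):
        raise ValueError("pass the shorter array as a")
    n, m = len(a), len(b)
    half = (n + m + 1) // 2
    for i in range(n + 1):
        j = half - i  # forced: the left side must hold exactly `half` elements
        a_left = a[i - 1] if i > 0 else float("-inf")
        a_right = a[i] if i < n else float("inf")
        b_left = b[j - 1] if j > 0 else float("-inf")
        b_right = b[j] if j < m else float("inf")
        if a_left <= b_right and b_left <= a_right:
            return i, j
    raise ValueError("inputs not sorted")


print(find_partition_linear([1, 3, 8], [2, 4, 5, 9]))  # (2, 2)
```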

Brute force — merge and pick

def find_median(a: list[int], b: list[int]) -> float:
    merged: list[int] = []
    i = j = 0
    while i < len(a) and j < len(b):
        if a[i] <= b[j]:
            merged.append(a[i])
            i += 1
        else:
            merged.append(b[j])
            j += 1
    merged.extend(a[i:])
    merged.extend(b[j:])

    total = len(merged)
    if total % 2 == 1:
        return float(merged[total // 2])
    return (merged[total // 2 - 1] + merged[total // 2]) / 2

O(n + m) time, O(n + m) space. This is the "obvious" answer: correct, but it misses the point of the question. It's acceptable as a 30-second "I'll come back to this" if you flag that a better bound exists.
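If you do fall back to merging, the extra list is avoidable. A sketch using the standard library's `heapq.merge`, which merges sorted iterables lazily, walks only as far as the middle and keeps the last two values seen. Time stays O(n + m), but extra space drops to O(1). The function name `find_median_merge_lazy` is illustrative.

```python
import heapq
from itertools import islice


def find_median_merge_lazy(a: list[int], b: list[int]) -> float:
    total = len(a) + len(b)
    if total == 0:
        raise ValueError("both arrays are empty")
    # Consume the lazily merged stream up to and including index total // 2,
    # remembering the previous value for the even-length case.
    prev = curr = None
    for value in islice(heapq.merge(a, b), total // 2 + 1):
        prev, curr = curr, value
    if total % 2 == 1:
        return float(curr)
    return (prev + curr) / 2


print(find_median_merge_lazy([1, 3], [2, 4]))  # 2.5
```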

Optimal — binary search on partitions

def find_median(a: list[int], b: list[int]) -> float:
    # Always binary-search the SHORTER array, so the search range is small.
    if len(a) > len(b):
        a, b = b, a

    n, m = len(a), len(b)
    if n + m == 0:
        raise ValueError("both arrays are empty")
    half = (n + m + 1) // 2          # number of elements in the LEFT partition

    lo, hi = 0, n
    while lo <= hi:
        i = (lo + hi) // 2           # take i from a, j from b, summing to half
        j = half - i

        a_left  = a[i - 1] if i > 0 else float('-inf')
        a_right = a[i]     if i < n else float('inf')
        b_left  = b[j - 1] if j > 0 else float('-inf')
        b_right = b[j]     if j < m else float('inf')

        if a_left <= b_right and b_left <= a_right:
            # Correct partition.
            if (n + m) % 2 == 1:
                return float(max(a_left, b_left))
            return (max(a_left, b_left) + min(a_right, b_right)) / 2
        elif a_left > b_right:
            hi = i - 1               # took too many from a
        else:
            lo = i + 1               # took too few from a

    raise ValueError("inputs not sorted")

O(log min(n, m)) time, O(1) space. Always swap so a is the shorter array: that ties the search range, and hence the log factor, to min(n, m), and it guarantees that j = half - i stays within [0, m] for every i in [0, n].
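A quick way to see why the swap matters is to compute the range of j values the search would force, with and without the shorter array first. The helper `forced_j_range` is hypothetical; it just evaluates j = half - i at both ends of the search range i ∈ [0, len(a)].

```python
def forced_j_range(len_a: int, len_b: int) -> tuple[int, int]:
    """Hypothetical helper: smallest and largest j = half - i as i runs over 0..len_a."""
    half = (len_a + len_b + 1) // 2
    return half - len_a, half - 0


# Shorter array searched: j stays inside [0, len_b] = [0, 5].
print(forced_j_range(1, 5))  # (2, 3)
# Longer array searched: j escapes [0, len_b] = [0, 1] at both ends.
print(forced_j_range(5, 1))  # (-2, 3)
```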

The four boundary values a_left, a_right, b_left, b_right form a 2×2 grid around the partition. Sortedness already gives a_left ≤ a_right and b_left ≤ b_right, so the partition is correct exactly when the two cross conditions hold: a_left ≤ b_right and b_left ≤ a_right. Together they say every left-half element is no greater than every right-half element across both arrays. The infinity sentinels handle empty sides cleanly, without special cases.
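Binary search is valid here because the outcome of the check is monotone in i: as i grows, a_left can only grow and b_right can only shrink (since j shrinks), so "took too few from a" can only give way to "correct", which can only give way to "took too many". The sketch below (hypothetical name `classify`) labels every candidate split so the pattern is visible.

```python
def classify(a: list[int], b: list[int]) -> list[str]:
    """Hypothetical helper: label each split i of the shorter array a as 'few', 'ok', or 'many'."""
    n, m = len(a), len(b)
    if n > m:
        raise ValueError("pass the shorter array as a")
    half = (n + m + 1) // 2
    labels = []
    for i in range(n + 1):
        j = half - i
        a_left = a[i - 1] if i > 0 else float("-inf")
        a_right = a[i] if i < n else float("inf")
        b_left = b[j - 1] if j > 0 else float("-inf")
        b_right = b[j] if j < m else float("inf")
        if a_left <= b_right and b_left <= a_right:
            labels.append("ok")
        elif a_left > b_right:
            labels.append("many")  # took too many from a: the search moves hi left
        else:
            labels.append("few")   # took too few from a: the search moves lo right
    return labels


print(classify([1, 3, 8], [2, 4, 5, 9]))  # ['few', 'few', 'ok', 'many']
```

With duplicates several splits can be "ok" at once (for example, every split of [2, 2] against [2, 2, 2]); the binary search simply stops at the first one it probes.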

The hardest part is the off-by-one in half. Using (n + m + 1) // 2 means the left side gets the extra element when the total is odd, so the median equals max(a_left, b_left). Using (n + m) // 2 would put the extra on the right and require swapping which max/min you read for odd totals. Stick with the formulation above.
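For comparison, here is a sketch of the (n + m) // 2 variant the paragraph warns against (hypothetical name `find_median_right_heavy`). Only two things change: `half`, and the odd-total return, which now reads min(a_right, b_right) because the extra element sits on the right.

```python
def find_median_right_heavy(a: list[int], b: list[int]) -> float:
    # Hypothetical variant: same search, but the RIGHT side gets the extra element.
    if len(a) > len(b):
        a, b = b, a
    n, m = len(a), len(b)
    if n + m == 0:
        raise ValueError("both arrays are empty")
    half = (n + m) // 2              # left side is the smaller half when n + m is odd

    lo, hi = 0, n
    while lo <= hi:
        i = (lo + hi) // 2
        j = half - i
        a_left = a[i - 1] if i > 0 else float("-inf")
        a_right = a[i] if i < n else float("inf")
        b_left = b[j - 1] if j > 0 else float("-inf")
        b_right = b[j] if j < m else float("inf")

        if a_left <= b_right and b_left <= a_right:
            if (n + m) % 2 == 1:
                return float(min(a_right, b_right))  # median is the smallest right-side value
            return (max(a_left, b_left) + min(a_right, b_right)) / 2
        elif a_left > b_right:
            hi = i - 1
        else:
            lo = i + 1
    raise ValueError("inputs not sorted")


print(find_median_right_heavy([1, 3, 8], [2, 4, 5, 9]))  # 4.0
```

Both versions are correct; the point is that the odd-total read has to match whichever side receives the extra element.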

Walkthrough

a = [1, 3], b = [2, 4]
n = 2, m = 2, half = 2

i=1 j=1
  a_left=1  a_right=3
  b_left=2  b_right=4
  1 ≤ 4 and 2 ≤ 3 → partition correct
  total even → (max(1, 2) + min(3, 4)) / 2 = (2 + 3) / 2 = 2.5
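To produce walkthroughs like the one above for other inputs, one option is to record each probe. The sketch below repeats the optimal loop and returns every (i, j, a_left, a_right, b_left, b_right) it inspects alongside the median; the name `trace_median` and its return shape are assumptions made for illustration.

```python
def trace_median(a: list[int], b: list[int]) -> tuple[float, list[tuple]]:
    """Hypothetical helper: run the partition search and record the boundary values at each probe."""
    if len(a) > len(b):
        a, b = b, a
    n, m = len(a), len(b)
    if n + m == 0:
        raise ValueError("both arrays are empty")
    half = (n + m + 1) // 2
    steps = []
    lo, hi = 0, n
    while lo <= hi:
        i = (lo + hi) // 2
        j = half - i
        a_left = a[i - 1] if i > 0 else float("-inf")
        a_right = a[i] if i < n else float("inf")
        b_left = b[j - 1] if j > 0 else float("-inf")
        b_right = b[j] if j < m else float("inf")
        steps.append((i, j, a_left, a_right, b_left, b_right))
        if a_left <= b_right and b_left <= a_right:
            if (n + m) % 2 == 1:
                return float(max(a_left, b_left)), steps
            return (max(a_left, b_left) + min(a_right, b_right)) / 2, steps
        elif a_left > b_right:
            hi = i - 1
        else:
            lo = i + 1
    raise ValueError("inputs not sorted")


median, steps = trace_median([1, 3, 8], [2, 4, 5, 9])
for step in steps:
    print("i=%d j=%d  a_left=%s a_right=%s  b_left=%s b_right=%s" % step)
print(median)
```

For this odd-total input the first probe (i=1, j=3) fails because b_left = 5 > a_right = 3, so lo moves right; the second probe (i=2, j=2) is correct and the median is max(3, 4) = 4.0.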

Edge cases
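One way to exercise the usual edge cases (one array empty, very different lengths, all of a below or above all of b, heavy duplicates) is to cross-check the partition search against plain sorting on hand-picked and random inputs. This is a sketch: the optimal solution is repeated (with an empty-input guard) so the block runs on its own, and the names `median_by_sorting` and `check` are hypothetical.

```python
import random


def find_median(a: list[int], b: list[int]) -> float:
    # Partition binary search, as in the optimal solution above, plus an empty-input guard.
    if len(a) > len(b):
        a, b = b, a
    n, m = len(a), len(b)
    if n + m == 0:
        raise ValueError("both arrays are empty")
    half = (n + m + 1) // 2
    lo, hi = 0, n
    while lo <= hi:
        i = (lo + hi) // 2
        j = half - i
        a_left = a[i - 1] if i > 0 else float("-inf")
        a_right = a[i] if i < n else float("inf")
        b_left = b[j - 1] if j > 0 else float("-inf")
        b_right = b[j] if j < m else float("inf")
        if a_left <= b_right and b_left <= a_right:
            if (n + m) % 2 == 1:
                return float(max(a_left, b_left))
            return (max(a_left, b_left) + min(a_right, b_right)) / 2
        elif a_left > b_right:
            hi = i - 1
        else:
            lo = i + 1
    raise ValueError("inputs not sorted")


def median_by_sorting(a: list[int], b: list[int]) -> float:
    # Hypothetical reference answer: sort everything and read the middle.
    merged = sorted(a + b)
    k = len(merged)
    if k % 2 == 1:
        return float(merged[k // 2])
    return (merged[k // 2 - 1] + merged[k // 2]) / 2


def check(trials: int = 2000, seed: int = 0) -> None:
    # Hypothetical harness: hand-picked edge cases first, then small random ones.
    rng = random.Random(seed)
    cases = [
        ([], [1]),                     # one array empty, odd total
        ([], [1, 2]),                  # one array empty, even total
        ([1, 2, 3], [4, 5, 6]),        # all of a below all of b
        ([4, 5, 6], [1, 2, 3]),        # all of a above all of b
        ([2, 2, 2], [2, 2]),           # all duplicates
        ([5], [1, 2, 3, 4, 6, 7, 8]),  # very different lengths
    ]
    for _ in range(trials):
        a = sorted(rng.randint(-5, 5) for _ in range(rng.randint(0, 6)))
        b = sorted(rng.randint(-5, 5) for _ in range(rng.randint(0, 6)))
        if a or b:  # the median of two empty arrays is undefined
            cases.append((a, b))
    for a, b in cases:
        got, want = find_median(a, b), median_by_sorting(a, b)
        assert got == want, (a, b, got, want)


check()
print("all cases agree")
```

The small value range (-5 to 5) is a deliberate choice so that random cases hit duplicates and ties often.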

Related