Kth Largest Element


Standard · Heap / Priority Queue · heap · quickselect

The problem

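Given an unsorted array of integers and an integer k, return the k-th largest element: the element at position k when the array is sorted in descending order, so k = 1 is the maximum. Duplicates count separately (this is not the k-th distinct value). Flipping the comparison gives the equivalent k-th smallest problem.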
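To make the definition concrete, and to show the flip between largest and smallest, here is a minimal sketch. `kth_largest_ascending` is a hypothetical helper, not one of the solutions below: it maps the k-th largest to index n - k of the ascending sort. The sample input is an assumption, chosen to contain duplicates so you can see repeated values counting separately.

```python
def kth_largest_ascending(nums: list[int], k: int) -> int:
    """k-th largest == (n - k + 1)-th smallest == index n - k of the ascending sort.
    Duplicates count separately, exactly as in the sorted-descending definition."""
    if not 1 <= k <= len(nums):
        raise ValueError("k must satisfy 1 <= k <= len(nums)")
    return sorted(nums)[len(nums) - k]

nums = [3, 2, 3, 1, 2, 4, 5, 5, 6]
print(kth_largest_ascending(nums, 4))        # 4
print(sorted(nums, reverse=True)[4 - 1])     # 4, same answer from the descending view
```

Both 5s occupy their own positions in the descending order [6, 5, 5, 4, ...], which is why the 4th largest is 4 rather than 3.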

Three solutions worth knowing

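This problem is special because three very different algorithms apply, each with a different trade-off: sorting is the simplest at O(n log n), a size-k min-heap runs in O(n log k) time with O(k) extra space, and quickselect runs in expected O(n) time.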

Baseline: sort

def kth_largest_sort(nums: list[int], k: int) -> int:
    """O(n log n).  Trivial baseline."""
    return sorted(nums, reverse=True)[k - 1]

Heap of size k

import heapq

def kth_largest_heap(nums: list[int], k: int) -> int:
    """O(n log k).  Maintain a min-heap of size k; its root is the k-th largest."""
    h = []
    for x in nums:
        if len(h) < k:
            heapq.heappush(h, x)
        elif x > h[0]:
            heapq.heapreplace(h, x)         # pop + push in one step
    return h[0]
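The standard library packages a similar idea: `heapq.nlargest(k, iterable)` returns the k largest values in descending order (documented as equivalent to `sorted(iterable, reverse=True)[:k]`), so its last entry is the k-th largest. A minimal sketch; the wrapper name `kth_largest_nlargest` is hypothetical.

```python
import heapq

def kth_largest_nlargest(nums: list[int], k: int) -> int:
    """Library version of the size-k heap idea."""
    if not 1 <= k <= len(nums):
        raise ValueError("k must satisfy 1 <= k <= len(nums)")
    # nlargest returns the k largest values, largest first; the last one is the answer.
    return heapq.nlargest(k, nums)[-1]

print(kth_largest_nlargest([3, 2, 1, 5, 6, 4], 2))   # 5
```

The `heapq` documentation notes that `nlargest` performs best for small k; for larger k, plain `sorted()` is usually the better choice.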

Quickselect (expected O(n))

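Quickselect is quicksort that recurses only into the side containing the target index. Intuition for the linear bound: if every partition discarded half of the remaining elements, the work would follow T(n) = T(n/2) + n, which sums to n + n/2 + n/4 + ... < 2n. A random pivot does not split exactly in half, but it discards a constant fraction often enough that the expected number of comparisons is still linear in n. Pivot selection matters: a deterministic "first element" pivot degrades to O(n²) on already-sorted input, while a random pivot makes the quadratic case astronomically unlikely (though not impossible). Median-of-medians chooses a pivot deterministically in linear time and guarantees O(n) in the worst case; it is elegant in theory but rarely used in practice because of its large constant factor.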
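Median-of-medians is the one pivot rule mentioned without code, so here is a minimal sketch under simplifying assumptions: it builds new lists instead of partitioning in place (O(n) extra space), and it splits three ways (greater / equal / smaller) so duplicates cannot stall progress. `select_mom` and `kth_largest_mom` are hypothetical names, not from the original solutions.

```python
def select_mom(a: list[int], i: int) -> int:
    """Return the element at 0-indexed position i of sorted(a, reverse=True).
    Pivots come from median-of-medians, which bounds every split and gives
    worst-case O(n) time."""
    while True:
        if len(a) <= 5:
            return sorted(a, reverse=True)[i]
        # Median of each group of 5 (the last group may be smaller).
        groups = [sorted(a[j:j + 5]) for j in range(0, len(a), 5)]
        medians = [g[len(g) // 2] for g in groups]
        # Recursively take the median of those medians as the pivot.
        pivot = select_mom(medians, len(medians) // 2)
        greater = [x for x in a if x > pivot]
        smaller = [x for x in a if x < pivot]
        n_equal = len(a) - len(greater) - len(smaller)
        if i < len(greater):
            a = greater                      # answer is among the larger values
        elif i < len(greater) + n_equal:
            return pivot                     # answer equals the pivot
        else:
            i -= len(greater) + n_equal      # skip everything >= pivot
            a = smaller

def kth_largest_mom(nums: list[int], k: int) -> int:
    if not 1 <= k <= len(nums):
        raise ValueError("k must satisfy 1 <= k <= len(nums)")
    return select_mom(nums, k - 1)

print(kth_largest_mom([3, 2, 1, 5, 6, 4], 2))   # 5
```

The pivot is larger than roughly 3/10 of the elements and smaller than roughly 3/10, so each round keeps at most about 7/10 of the input; the extra sorting and list building is the constant factor the paragraph refers to.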

import random

def kth_largest(nums: list[int], k: int) -> int:
    """Expected O(n) via quickselect.  Find the element at index k-1 in a sorted-desc view."""
    if not 1 <= k <= len(nums):
        raise ValueError("k must satisfy 1 <= k <= len(nums)")
    target = k - 1                          # 0-indexed position in the sorted-desc array
    lo, hi = 0, len(nums) - 1
    a = nums[:]                             # don't mutate caller's array

    while True:                             # invariant: target stays within [lo, hi]
        pivot_idx = partition(a, lo, hi)
        if pivot_idx == target:
            return a[pivot_idx]
        if pivot_idx < target:
            lo = pivot_idx + 1
        else:
            hi = pivot_idx - 1

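def partition(a, lo, hi):
    """Lomuto-style partition of a[lo..hi] around a RANDOM pivot, for DESCENDING order.
    Returns the final index of the pivot.
    Afterwards, elements left of the pivot are > the pivot value; elements right of it are <= the pivot value.
    """
    p = random.randint(lo, hi)              # random pivot index, moved to the end
    a[p], a[hi] = a[hi], a[p]
    pivot = a[hi]
    store = lo                              # next slot for an element > pivot
    for i in range(lo, hi):
        # Strict > matches the docstring; >= would also be correct. Neither balances
        # long runs of equal values: an all-equal array can still cost O(n^2).
        if a[i] > pivot:
            a[i], a[store] = a[store], a[i]
            store += 1
    a[store], a[hi] = a[hi], a[store]       # place the pivot between the two groups
    return store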
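The Lomuto `partition` above puts every copy of the pivot value on the same side, so an array of identical values can shrink by only one element per pass, which is O(n²) overall. A common remedy, swapped in here rather than taken from the original, is a three-way (Dutch national flag) partition: values greater than, equal to, and less than the pivot each get their own band, and if the target index falls in the middle band the search stops. A minimal sketch; `kth_largest_3way` is a hypothetical name, and it keeps the original's random pivot and descending order.

```python
import random

def kth_largest_3way(nums: list[int], k: int) -> int:
    """Quickselect with a three-way partition (descending order)."""
    if not 1 <= k <= len(nums):
        raise ValueError("k must satisfy 1 <= k <= len(nums)")
    a = nums[:]                              # don't mutate caller's array
    target = k - 1
    lo, hi = 0, len(a) - 1
    while True:
        pivot = a[random.randint(lo, hi)]
        # Invariant: a[lo:lt] > pivot, a[lt:i] == pivot, a[gt+1:hi+1] < pivot.
        lt, i, gt = lo, lo, hi
        while i <= gt:
            if a[i] > pivot:
                a[lt], a[i] = a[i], a[lt]
                lt += 1
                i += 1
            elif a[i] < pivot:
                a[i], a[gt] = a[gt], a[i]
                gt -= 1                      # a[i] is unexamined now; don't advance i
            else:
                i += 1
        # a[lt:gt+1] holds every copy of the pivot (never empty).
        if target < lt:
            hi = lt - 1
        elif target > gt:
            lo = gt + 1
        else:
            return pivot

print(kth_largest_3way([3, 2, 1, 5, 6, 4], 2))   # 5
print(kth_largest_3way([7] * 10_000, 5_000))     # 7, after a single pass
```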

Trace

nums = [3, 2, 1, 5, 6, 4], k = 2

Heap walk (k=2):
  x=3: h=[3]
  x=2: h=[2, 3]      (size 2)
  x=1: 1 <= h[0]=2, skip
  x=5: 5 > 2, replace top: h=[3, 5]
  x=6: 6 > 3, replace top: h=[5, 6]
  x=4: 4 <= 5, skip

return h[0] = 5    (2nd largest)
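The heap walk can be reproduced by instrumenting the size-k heap loop. `heap_walk` is a hypothetical helper that records a copy of the heap's underlying list after each element. `heapq` only guarantees that `h[0]` is the minimum; the snapshots show the list's internal layout, which for this input matches the states in the walk.

```python
import heapq

def heap_walk(nums: list[int], k: int) -> tuple[int, list[list[int]]]:
    """Run the size-k min-heap algorithm, recording the heap after each element."""
    h: list[int] = []
    states: list[list[int]] = []
    for x in nums:
        if len(h) < k:
            heapq.heappush(h, x)
        elif x > h[0]:
            heapq.heapreplace(h, x)          # drop the current root, keep x
        states.append(list(h))               # snapshot of the heap array
    return h[0], states

answer, states = heap_walk([3, 2, 1, 5, 6, 4], 2)
for x, state in zip([3, 2, 1, 5, 6, 4], states):
    print(f"x={x}: h={state}")
print("return h[0] =", answer)
```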

Quickselect (target index 1, i.e. 2nd largest):
  partition the entire array around some pivot, say 4:
    rearranged to [5, 6, 4, 3, 2, 1] (> 4 on left, <= 4 on right; pivot at index 2)
    pivot_idx = 2, target = 1.   pivot_idx > target -> hi = 1.
  partition [5, 6] around some pivot, say 6:
    [6, 5], pivot 6 ends at index 0.
    pivot_idx = 0, target = 1.   pivot_idx < target -> lo = 1.
  partition [5] around 5:
    pivot_idx = 1, target = 1.   match -> return a[1] = 5.
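The quickselect trace picks its pivots by hand ("say 4", "say 6"). A minimal sketch that makes the pivot rule a parameter can replay it: taking the last element of the current range (`lambda lo, hi: hi`) happens to choose exactly 4, then 6, then 5 on this input. The same hypothetical `quickselect_traced` helper also counts comparisons, which makes the pivot-selection claim from the quickselect section checkable: with a first-element rule on descending-sorted input and k = n, each pass removes a single element, for n(n-1)/2 comparisons. The partition loop mirrors `partition` (pivot swapped to the end, strict `>`); the helper name, `PivotRule`, and the step-tuple layout are assumptions for illustration.

```python
import random
from typing import Callable

# Hypothetical pivot rule: given the current range [lo, hi], return a pivot index in it.
PivotRule = Callable[[int, int], int]

def quickselect_traced(nums: list[int], k: int, choose_pivot: PivotRule):
    """Quickselect for the k-th largest with a pluggable pivot rule.
    Returns (answer, steps, comparisons); each step is
    (lo, hi, pivot value, final pivot index, snapshot of the array)."""
    a = nums[:]
    target = k - 1
    lo, hi = 0, len(a) - 1
    steps = []
    comparisons = 0
    while True:
        p = choose_pivot(lo, hi)
        a[p], a[hi] = a[hi], a[p]            # same Lomuto scheme as partition()
        pivot = a[hi]
        store = lo
        for i in range(lo, hi):
            comparisons += 1
            if a[i] > pivot:
                a[i], a[store] = a[store], a[i]
                store += 1
        a[store], a[hi] = a[hi], a[store]
        steps.append((lo, hi, pivot, store, a[:]))
        if store == target:
            return a[store], steps, comparisons
        if store < target:
            lo = store + 1
        else:
            hi = store - 1

# Replay the trace: the last element of each range is 4, then 6, then 5.
answer, steps, _ = quickselect_traced([3, 2, 1, 5, 6, 4], 2, lambda lo, hi: hi)
for lo, hi, pivot, idx, snapshot in steps:
    print(f"range [{lo}, {hi}], pivot {pivot} -> index {idx}: {snapshot}")
print("answer:", answer)

# Pivot rules on descending-sorted input, asking for the minimum (k = n).
n = 1000
data = list(range(n - 1, -1, -1))
_, _, first = quickselect_traced(data, n, lambda lo, hi: lo)
rng = random.Random(0)
_, _, rand = quickselect_traced(data, n, lambda lo, hi: rng.randint(lo, hi))
print("first-element pivot:", first, "comparisons")   # 499500 == n*(n-1)//2
print("random pivot:", rand, "comparisons")
```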

Complexity summary

Variations worth knowing