Lowest Common Ancestor of a Binary Tree

The problem

Given the root of a binary tree and two nodes p and q that are guaranteed to be in the tree, return their lowest common ancestor: the deepest node that has both p and q in its subtree (a node counts as a descendant of itself).

class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

Pattern: post-order, "found here or below"

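Define lca(root, p, q) to return, for the subtree rooted at root:

- the LCA of p and q, if both are in that subtree;
- whichever of p or q is in that subtree, if only one of them is;
- None, if neither is.

With that contract, the recursion is mechanical. Base case: return root when it is None or when it is p or q. Stopping at p (or q) without searching further is safe: if the other node lies below it, this node is the LCA; if the other node lies elsewhere, some ancestor will receive non-null results from both sides. Combine step: if both subtrees return non-null, p and q are on different sides, so root is the LCA. Otherwise, return whichever child result is non-null (or None), carrying upward whatever was found below.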
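
To see the contract at work, here is a minimal sketch of a hypothetical `lca_logged` helper (an illustration, not part of the original solution). It applies the same base case and combine step described above and additionally records `(node value, returned value)` for every non-empty call; reporting by value assumes node values are unique.

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lca_logged(root, p, q, log):
    """Hypothetical instrumented variant: same recursion, plus a log of
    (node value, returned value) pairs in post-order."""
    if root is None:
        return None                       # empty subtree: nothing found
    if root is p or root is q:
        log.append((root.val, root.val))  # found here: stop and report this node
        return root
    left = lca_logged(root.left, p, q, log)
    right = lca_logged(root.right, p, q, log)
    # Both sides found something -> root is the split point;
    # otherwise forward whichever side found something (or None).
    result = root if (left and right) else (left or right)
    log.append((root.val, result.val if result else None))
    return result


# Same shape as the tree in the trace section:
#         3
#       /   \
#      5     1
#     / \   / \
#    6   2 0   8
#       / \
#      7   4
n7, n4 = Node(7), Node(4)
n6, n2 = Node(6), Node(2, n7, n4)
n5 = Node(5, n6, n2)
n1 = Node(1, Node(0), Node(8))
root = Node(3, n5, n1)

log = []
answer = lca_logged(root, n6, n4, log)   # p = 6, q = 4
for node_val, returned in log:           # entries appear in post-order
    print(node_val, "->", returned)
```

Reading the log: node 7 returns None because neither target is below it, node 2 forwards 4 because only q lies below it, node 5 receives non-null results from both sides and returns itself, and node 3 carries 5 upward.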

Solution

def lca(root: Node | None, p: Node, q: Node) -> Node | None:
    if root is None or root is p or root is q:
        return root
    left  = lca(root.left,  p, q)
    right = lca(root.right, p, q)
    if left and right:    # p and q in different subtrees -> current root is the LCA
        return root
    return left if left else right

Trace

Tree:        3
            /   \
           5     1
          / \   / \
         6   2 0   8
            / \
           7   4

Find the LCA of p = 5 and q = 4 (nodes are referred to by their values below).

lca(3, 5, 4):
  3 is neither 5 nor 4 -> recurse.
  left = lca(5, 5, 4):
    root is 5 (== p) -> return 5.
  right = lca(1, 5, 4):
    1 is neither 5 nor 4 -> recurse.
    left  = lca(0, ...) = None
    right = lca(8, ...) = None
    -> return None.
  left=5, right=None -> return 5.

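Result: 5  (5 is itself an ancestor of 4). The search never descends below 5, so 4 is never visited; this is safe only because both nodes are guaranteed to be in the tree.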
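
The trace can be checked by running the solution on the same tree. This sketch repeats the `Node` class and `lca` from above so it runs on its own; the queries beyond 5 and 4 are illustrative additions, and `n0` to `n8` are just labels for the nodes.

```python
from __future__ import annotations  # lets `Node | None` work on Python < 3.10


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lca(root: Node | None, p: Node, q: Node) -> Node | None:
    if root is None or root is p or root is q:
        return root
    left = lca(root.left, p, q)
    right = lca(root.right, p, q)
    if left and right:  # p and q in different subtrees
        return root
    return left if left else right


# Rebuild the example tree, keeping a handle on every node.
n7, n4 = Node(7), Node(4)
n6, n2 = Node(6), Node(2, n7, n4)
n0, n8 = Node(0), Node(8)
n5 = Node(5, n6, n2)
n1 = Node(1, n0, n8)
n3 = Node(3, n5, n1)

print(lca(n3, n5, n4).val)  # 5: the traced query, 5 is an ancestor of 4
print(lca(n3, n6, n4).val)  # 5: 6 and 4 split at 5
print(lca(n3, n7, n8).val)  # 3: 7 and 8 are on opposite sides of the root
print(lca(n3, n7, n4).val)  # 2
```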

BST shortcut

If the tree is a binary search tree (BST), you can use its ordering instead of a full traversal. Walk down from the root: if both p and q are smaller than the current node, go left; if both are larger, go right; otherwise the current node is the split point, and therefore the LCA. This takes O(h) time and O(1) extra space, where h is the height of the tree.

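def lca_bst(root: Node | None, p: Node, q: Node) -> Node | None:
    """For a BST, walk down using the ordering (assumes distinct values)."""
    while root:
        if p.val < root.val and q.val < root.val:
            root = root.left        # both in the left subtree
        elif p.val > root.val and q.val > root.val:
            root = root.right       # both in the right subtree
        else:
            return root             # split point (or root is p or q): the LCA
    return None                     # only reached if p or q is not in the tree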
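
A quick usage check for `lca_bst`, on an illustrative BST that does not come from the original text (values assumed distinct). The block repeats `Node` and `lca_bst` so it runs on its own.

```python
from __future__ import annotations  # lets `Node | None` work on Python < 3.10


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lca_bst(root: Node | None, p: Node, q: Node) -> Node | None:
    """Walk down from the root using the BST ordering (assumes distinct values)."""
    while root:
        if p.val < root.val and q.val < root.val:
            root = root.left        # both targets lie in the left subtree
        elif p.val > root.val and q.val > root.val:
            root = root.right       # both targets lie in the right subtree
        else:
            return root             # targets split here (or root is p or q)
    return None                     # only reached if p or q is missing


# Illustrative BST:
#          6
#        /   \
#       2     8
#      / \   / \
#     0   4 7   9
#        / \
#       3   5
n3, n5 = Node(3), Node(5)
n0, n4 = Node(0), Node(4, n3, n5)
n7, n9 = Node(7), Node(9)
n2 = Node(2, n0, n4)
n8 = Node(8, n7, n9)
n6 = Node(6, n2, n8)

print(lca_bst(n6, n2, n8).val)  # 6: split at the root
print(lca_bst(n6, n2, n4).val)  # 2: 2 is an ancestor of 4
print(lca_bst(n6, n3, n5).val)  # 4
```

The second query shows the `else` branch also covering the case where the current node is one of the targets: at node 2, p.val is neither strictly smaller nor strictly larger than 2, so the walk stops there.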

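Complexity

lca (any binary tree): O(n) time, because each node is visited at most once, and O(h) extra space for the recursion stack, where h is the height of the tree (O(log n) if balanced, O(n) in the worst case).

lca_bst: O(h) time and O(1) extra space.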
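
One practical consequence in Python: the recursive `lca` uses one stack frame per level of depth, so on a very deep, unbalanced tree it can exceed CPython's recursion limit, while the loop in `lca_bst` needs only constant extra space. This sketch builds a right-leaning chain deeper than `sys.getrecursionlimit()` using a hypothetical helper `right_chain` (illustrative only); the chain is also a valid BST, so both functions apply.

```python
import sys


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lca(root, p, q):
    # Recursive post-order version: one stack frame per level of depth.
    if root is None or root is p or root is q:
        return root
    left = lca(root.left, p, q)
    right = lca(root.right, p, q)
    if left and right:
        return root
    return left if left else right


def lca_bst(root, p, q):
    # Iterative BST walk: constant extra space regardless of depth.
    while root:
        if p.val < root.val and q.val < root.val:
            root = root.left
        elif p.val > root.val and q.val > root.val:
            root = root.right
        else:
            return root
    return None


def right_chain(n):
    """Hypothetical helper: nodes 0..n-1, each the right child of the previous one."""
    nodes = [Node(i) for i in range(n)]
    for parent, child in zip(nodes, nodes[1:]):
        parent.right = child
    return nodes


depth = sys.getrecursionlimit() + 100   # deeper than the current recursion limit
nodes = right_chain(depth)
p, q = nodes[-2], nodes[-1]

print(lca_bst(nodes[0], p, q) is p)     # True: the loop handles any depth

try:
    lca(nodes[0], p, q)
    recursive_failed = False
except RecursionError:
    recursive_failed = True
print(recursive_failed)                 # True: recursion depth exceeds the limit
```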

Variations worth knowing