Validate Binary Search Tree


Standard · Trees · trees, recursion, invariants

The problem

Given the root of a binary tree, decide whether it is a valid binary search tree (BST). A BST requires that, for every node, all values in its left subtree are strictly less than the node's value and all values in its right subtree are strictly greater. The "strictly" matters: a duplicate value breaks the property.

class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
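As a baseline, the definition can be checked literally by scanning each node's entire left and right subtrees. This brute-force sketch (`is_bst_brute_force` and `subtree_values` are hypothetical helpers, not part of the original) is quadratic in the worst case, but it is handy as a test oracle for the faster solutions. It repeats the `Node` class so it runs on its own.

```python
from typing import Iterator, Optional


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_values(node: Optional[Node]) -> Iterator:
    """Yield every value in the subtree rooted at node (explicit stack, any order)."""
    stack = [node] if node is not None else []
    while stack:
        cur = stack.pop()
        yield cur.val
        if cur.left is not None:
            stack.append(cur.left)
        if cur.right is not None:
            stack.append(cur.right)


def is_bst_brute_force(root: Optional[Node]) -> bool:
    """Literal reading of the definition; O(n^2) time in the worst case."""
    if root is None:
        return True
    # Every left-subtree value must be strictly less than root.val ...
    if any(v >= root.val for v in subtree_values(root.left)):
        return False
    # ... and every right-subtree value strictly greater.
    if any(v <= root.val for v in subtree_values(root.right)):
        return False
    # The same condition must hold at every node, so recurse.
    return is_bst_brute_force(root.left) and is_bst_brute_force(root.right)


bad = Node(5, Node(1), Node(6, Node(3), Node(7)))  # 3 sits in 5's right subtree
dup = Node(2, Node(2))                              # duplicate value
good = Node(5, Node(1), Node(7, Node(6), Node(8)))
print(is_bst_brute_force(bad), is_bst_brute_force(dup), is_bst_brute_force(good))
# False False True
```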

The trap: local check isn't enough

The naive idea is to check, at each node, that its left child is smaller and its right child is larger. That's necessary but not sufficient. A grandchild can sit on the wrong side of an ancestor while still satisfying the local rule with its own parent.

def is_bst_wrong(root):
    """WRONG — checks only immediate children."""
    if root is None: return True
    if root.left  and root.left.val  >= root.val: return False
    if root.right and root.right.val <= root.val: return False
    return is_bst_wrong(root.left) and is_bst_wrong(root.right)

# Wrongly returns True for this tree:
#                  5
#                /   \
#               1     6
#                    / \
#                   3   7    <- 3 < 5, but it's in 5's RIGHT subtree, so invalid
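Running the naive checker on that tree confirms the failure: every parent-child pair passes, so it answers True for a tree that is not a BST. The block repeats `Node` and `is_bst_wrong` so it runs on its own; `counterexample` is an illustrative name.

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_bst_wrong(root):
    """WRONG: checks only immediate children (same logic as above)."""
    if root is None:
        return True
    if root.left and root.left.val >= root.val:
        return False
    if root.right and root.right.val <= root.val:
        return False
    return is_bst_wrong(root.left) and is_bst_wrong(root.right)


# 3 is in 5's right subtree but 3 < 5, so this is NOT a BST.
counterexample = Node(5, Node(1), Node(6, Node(3), Node(7)))
print(is_bst_wrong(counterexample))  # True, which is the wrong answer
```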

Pattern: recursion with bounds

Every node has an allowed open interval (lo, hi), and its value must lie strictly inside it. The root starts with (−∞, +∞). Recursing into the left child tightens the upper bound to the current node's value; recursing into the right child tightens the lower bound to it. This is the load-bearing template for any tree problem whose "global" constraint can be expressed as a per-node (pointwise) condition.

Solution: bounds recursion

def is_bst(root) -> bool:
    def walk(node, lo, hi) -> bool:
        if node is None: return True
        if not (lo < node.val < hi): return False
        return walk(node.left,  lo, node.val) and \
               walk(node.right, node.val, hi)
    return walk(root, float('-inf'), float('inf'))
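The float('-inf') and float('inf') sentinels assume numeric values: in Python 3, comparing a float with a str raises TypeError, so `is_bst` as written cannot check a tree of strings. A minimal variant, sketched under the assumption that no node stores None as its value, uses None to mean "no bound on this side" and relies only on `<` between node values. The name `is_bst_any` is hypothetical.

```python
from typing import Any, Optional


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_bst_any(root: Optional[Node]) -> bool:
    """Bounds recursion where None means 'unbounded on this side'.

    Assumes no node value is None and that all values are mutually comparable.
    """
    def walk(node: Optional[Node], lo: Any, hi: Any) -> bool:
        if node is None:
            return True
        if lo is not None and not lo < node.val:
            return False
        if hi is not None and not node.val < hi:
            return False
        return walk(node.left, lo, node.val) and walk(node.right, node.val, hi)

    return walk(root, None, None)


words = Node("m", Node("c"), Node("t", Node("p"), Node("z")))
print(is_bst_any(words))                       # True
print(is_bst_any(Node("m", None, Node("m"))))  # False: duplicate value
```

The trade-off is two explicit None checks per node in exchange for working with any totally ordered value type.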

Trace

Tree:        5
            / \
           1   6
              / \
             3   7

walk(5, -inf, +inf):  -inf < 5 < +inf ✓
  walk(1, -inf, 5):   -inf < 1 < 5    ✓  (no children)
  walk(6, 5, +inf):    5  < 6 < +inf  ✓
    walk(3, 5, 6):     3 not > 5      ✗  -> returns False

The naive (immediate-children-only) walker would have seen:
  at 6: left=3 < 6 ✓, right=7 > 6 ✓
  ...and missed the global violation.
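To reproduce the trace above, a hypothetical instrumented copy of `walk` can record each call with its interval. It also shows the short-circuit: once `walk(3, 5, 6)` fails, `and` stops, so node 7 is never visited. `is_bst_traced` and its `log` parameter are illustrative names.

```python
from typing import List


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_bst_traced(root, log: List[str]) -> bool:
    """Same bounds recursion as is_bst, but records each call in log."""
    def walk(node, lo, hi, depth: int) -> bool:
        if node is None:
            return True
        ok = lo < node.val < hi
        log.append(f"{'  ' * depth}walk({node.val}, {lo}, {hi}): {'ok' if ok else 'FAIL'}")
        if not ok:
            return False
        # 'and' short-circuits: the right subtree is skipped once the left fails.
        return walk(node.left, lo, node.val, depth + 1) and \
               walk(node.right, node.val, hi, depth + 1)
    return walk(root, float('-inf'), float('inf'), 0)


log: List[str] = []
tree = Node(5, Node(1), Node(6, Node(3), Node(7)))
print(is_bst_traced(tree, log))  # False
print("\n".join(log))
# walk(5, -inf, inf): ok
#   walk(1, -inf, 5): ok
#   walk(6, 5, inf): ok
#     walk(3, 5, 6): FAIL
```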

Alternative: in-order traversal

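An in-order traversal of a valid BST visits values in strictly increasing order, and conversely, a binary tree whose in-order sequence is strictly increasing is a valid BST. So: walk in-order, keep the previously visited value in a running prev, and check that each node's value is strictly greater than it. Time is the same O(n) and recursion depth the same O(h) as the bounds version; the difference is that it carries one running value instead of per-call bound arithmetic.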

def is_bst_inorder(root) -> bool:
    """In-order traversal of a BST is strictly increasing."""
    prev = float('-inf')  # last value visited in-order
    def walk(node) -> bool:
        nonlocal prev
        if node is None: return True
        if not walk(node.left): return False
        if node.val <= prev: return False
        prev = node.val
        return walk(node.right)
    return walk(root)
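Both recursive versions need one stack frame per tree level, so a very deep tree can hit Python's recursion limit. A sketch of the in-order check with an explicit stack (`is_bst_inorder_iter` is a hypothetical name) removes that limit while keeping the early exit at the first non-increasing value.

```python
from typing import List, Optional


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_bst_inorder_iter(root: Optional[Node]) -> bool:
    """In-order BST check using an explicit stack instead of recursion."""
    stack: List[Node] = []
    prev = None          # value of the previously visited node
    have_prev = False    # False until the first node is visited
    node = root
    while stack or node is not None:
        # Push the current node and all of its left descendants.
        while node is not None:
            stack.append(node)
            node = node.left
        node = stack.pop()
        # Visit in-order: each value must be strictly greater than the last.
        if have_prev and node.val <= prev:
            return False
        prev, have_prev = node.val, True
        # Continue with the right subtree.
        node = node.right
    return True


# Usage: a 10,000-node right-leaning chain, deeper than the default recursion limit.
chain = Node(0)
tail = chain
for v in range(1, 10_000):
    tail.right = Node(v)
    tail = tail.right
print(is_bst_inorder_iter(chain))  # True
```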

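Complexity

Both solutions visit each node at most once, so time is O(n). Extra space is the recursion stack, O(h) for a tree of height h: about O(log n) when the tree is balanced and O(n) when it degenerates into a chain. In CPython, a chain deeper than roughly 1000 nodes exceeds the default recursion limit, so very deep inputs call for an explicit stack.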
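A hypothetical instrumented copy of the bounds recursion makes both costs visible: it counts visited nodes (time) and the deepest active call (stack space). The helpers `build_balanced`, `build_chain` and `bounds_check_stats` are illustrative, not from the original. The chain is kept at 500 nodes so the recursive walk stays within CPython's default recursion limit.

```python
from typing import Optional, Tuple


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_balanced(lo: int, hi: int) -> Optional[Node]:
    """Hypothetical helper: balanced BST holding the integers lo..hi-1."""
    if lo >= hi:
        return None
    mid = (lo + hi) // 2
    return Node(mid, build_balanced(lo, mid), build_balanced(mid + 1, hi))


def build_chain(n: int) -> Optional[Node]:
    """Hypothetical helper: right-leaning chain 0 -> 1 -> ... -> n-1 (still a BST)."""
    root = None
    for v in reversed(range(n)):
        root = Node(v, None, root)
    return root


def bounds_check_stats(root) -> Tuple[bool, int, int]:
    """Bounds recursion that also returns (valid, nodes visited, max depth)."""
    stats = {"visits": 0, "max_depth": 0}

    def walk(node, lo, hi, depth: int) -> bool:
        if node is None:
            return True
        stats["visits"] += 1
        stats["max_depth"] = max(stats["max_depth"], depth)
        if not (lo < node.val < hi):
            return False
        return walk(node.left, lo, node.val, depth + 1) and \
               walk(node.right, node.val, hi, depth + 1)

    ok = walk(root, float('-inf'), float('inf'), 1)
    return ok, stats["visits"], stats["max_depth"]


print(bounds_check_stats(build_balanced(0, 511)))  # (True, 511, 9)
print(bounds_check_stats(build_chain(500)))        # (True, 500, 500)
```

Both trees are visited in full, but the perfect 511-node tree needs only 9 levels of stack while the 500-node chain needs 500.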

Variations worth knowing