Binary Tree Maximum Path Sum

Difficulty: Hard · Topic: Trees · Tags: trees, dp

The problem

Given the root of a binary tree (with possibly negative node values), return the maximum sum of any non-empty path. A path is any sequence of nodes connected by parent–child edges; it need not pass through the root.

class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
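Interview platforms commonly serialize trees as a level-order list, with None marking a missing child. In that encoding, the tree used in the trace below is [-10, 9, 20, None, None, 15, 7]. The helper below is a hypothetical sketch that assumes this encoding. build_tree is not part of the original problem.

```python
from collections import deque
from typing import Optional


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: list) -> Optional[Node]:
    """Build a tree from a level-order list where None marks a missing child
    (assumed LeetCode-style encoding, for illustration only)."""
    if not values or values[0] is None:
        return None
    root = Node(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if present) is this node's left child.
        if i < len(values) and values[i] is not None:
            node.left = Node(values[i])
            queue.append(node.left)
        i += 1
        # The value after that is its right child.
        if i < len(values) and values[i] is not None:
            node.right = Node(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([-10, 9, 20, None, None, 15, 7])
print(root.right.left.val)  # 15
```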

The trick: the recursion's return value is NOT the answer

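This is the canonical "tree DP with two quantities" problem. At each node, you need to know two different things:

1. The best path that starts at this node and goes straight down through at most one child. This is the only shape a parent can extend, because a path cannot fork.
2. The best path that bends at this node, joining the best downward paths from both children through node.val. This path is a candidate for the final answer, but it cannot be extended upward.

The recursive function returns the first quantity (which is what the parent needs). The second is recorded in an external best variable on the way back up. This separation is the load-bearing trick; many problems on trees have this structure (diameter, longest univalue path, lowest balanced cost, etc.).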
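Here is a sketch of the same two-quantity pattern applied to tree diameter, one of the problems named above. Diameter is measured here in edges, which is an assumed convention because some statements count nodes. The names diameter and depth are illustrative. The function returns the longest downward chain and records the longest path that bends at the node. No clamping is needed because lengths are never negative.

```python
from typing import Optional


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def diameter(root: Optional[Node]) -> int:
    """Length, in edges, of the longest path between any two nodes."""
    best = 0

    def depth(node: Optional[Node]) -> int:
        nonlocal best
        if node is None:
            return 0
        left = depth(node.left)
        right = depth(node.right)
        # Recorded quantity: the path bending here has left + right edges.
        best = max(best, left + right)
        # Returned quantity: number of nodes on the longest downward chain.
        return 1 + max(left, right)

    depth(root)
    return best


tree = Node(-10, Node(9), Node(20, Node(15), Node(7)))
print(diameter(tree))  # 3, e.g. 9 -> -10 -> 20 -> 15
```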

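The other key detail: negative subtree contributions are clamped to zero. Including a negative-sum downward path can only lower the total, so we replace it with 0, which means extending nothing into that subtree. This handles "what if every subtree is negative" without a special case. Because node.val itself is always counted, the recorded path is never empty. In an all-negative tree the answer is simply the largest single node value, which is also why best must start at float('-inf') rather than 0.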
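The sketch below shows the clamp at work. It runs the same recursion, with the max(0, ...) step controlled by a clamp flag. Both the flag and the function name max_path_sum_demo are hypothetical. Take a root 2 with children -1 and 3. Without clamping, the bending path at the root is forced to include the -1 child, so the function reports 4. The clamped version finds 2 -> 3 = 5.

```python
from typing import Optional


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_path_sum_demo(root: Node, clamp: bool = True) -> int:
    """Same recursion as the solution; `clamp` toggles the max(0, ...) step."""
    best = float("-inf")

    def gain(node: Optional[Node]) -> int:
        nonlocal best
        if node is None:
            return 0
        left = gain(node.left)
        right = gain(node.right)
        if clamp:
            # Drop a negative downward path instead of dragging it along.
            left, right = max(0, left), max(0, right)
        best = max(best, node.val + left + right)
        return node.val + max(left, right)

    gain(root)
    return best


mixed = Node(2, Node(-1), Node(3))
print(max_path_sum_demo(mixed, clamp=True))   # 5
print(max_path_sum_demo(mixed, clamp=False))  # 4 (wrong)
print(max_path_sum_demo(Node(-3, Node(-1), Node(-2))))  # -1 (single best node)
```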

Solution

def max_path_sum(root: Node | None) -> int:
    """Return the maximum path sum of any path in the tree.
    A 'path' is a sequence of nodes connected by parent-child edges; need not pass through root.
    """
    best = float('-inf')
    def gain(node: Node | None) -> int:
        nonlocal best
        if node is None:
            return 0
        left  = max(0, gain(node.left))     # ignore negative contributions
        right = max(0, gain(node.right))
        # Path THROUGH this node uses both children
        best = max(best, node.val + left + right)
        # Path going UP from this node uses at most ONE child
        return node.val + max(left, right)
    gain(root)
    return best
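When practicing, it helps to check the O(n) solution against a slow but obviously correct reference. The sketch below, under the hypothetical name brute_force_max_path_sum, treats the tree as an undirected graph. From every start node it walks to every other node, summing values along the way. Because the path between two nodes in a tree is unique, this enumerates every non-empty path, counting multi-node paths once per direction. It runs in O(n^2), so it is only suitable for small test trees, for example when comparing against max_path_sum on randomly generated ones.

```python
from collections import defaultdict


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def brute_force_max_path_sum(root: Node) -> int:
    """O(n^2) reference: best sum over every simple path in the tree."""
    # Build an undirected adjacency list; Node objects hash by identity.
    adj = defaultdict(list)
    nodes = []
    stack = [root]
    while stack:
        node = stack.pop()
        nodes.append(node)
        for child in (node.left, node.right):
            if child is not None:
                adj[node].append(child)
                adj[child].append(node)
                stack.append(child)

    best = float("-inf")
    for start in nodes:
        # Walk outward from `start`; each node reached ends one path from it.
        seen = {start}
        frontier = [(start, start.val)]
        while frontier:
            node, total = frontier.pop()
            best = max(best, total)
            for nxt in adj[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    frontier.append((nxt, total + nxt.val))
    return best


tree = Node(-10, Node(9), Node(20, Node(15), Node(7)))
print(brute_force_max_path_sum(tree))  # 42
```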

Trace

Tree:        -10
            /    \
           9      20
                 /  \
                15   7

gain(9):
  left=0, right=0 (no children).
  best = max(best, 9 + 0 + 0) = 9.
  return 9 + 0 = 9.

gain(15):
  left=0, right=0.
  best = max(best, 15) = 15.
  return 15.

gain(7):
  left=0, right=0.
  best = max(best, 7) = 15 (unchanged, 7 < 15).
  return 7.

gain(20):
  left = max(0, gain(15)) = 15.
  right = max(0, gain(7)) = 7.
  best = max(best, 20 + 15 + 7) = 42.
  return 20 + max(15, 7) = 35.

gain(-10):
  left  = max(0, gain(9))  = 9.
  right = max(0, gain(20)) = 35.
  best = max(best, -10 + 9 + 35) = 42 (unchanged).
  return -10 + max(9, 35) = 25.

return best = 42   (path 15 -> 20 -> 7)
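The trace can be reproduced mechanically by logging each call. The sketch below uses the hypothetical name traced_max_path_sum. It runs the same recursion and records (node.val, left, right, best, returned) in post-order, which is the order in which the calls above finish.

```python
from typing import Optional


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def traced_max_path_sum(root: Node):
    """Same recursion as the solution, plus a per-call log for inspection."""
    best = float("-inf")
    log = []

    def gain(node: Optional[Node]) -> int:
        nonlocal best
        if node is None:
            return 0
        left = max(0, gain(node.left))
        right = max(0, gain(node.right))
        best = max(best, node.val + left + right)
        result = node.val + max(left, right)
        # Logged after both children return, i.e. in post-order.
        log.append((node.val, left, right, best, result))
        return result

    gain(root)
    return best, log


tree = Node(-10, Node(9), Node(20, Node(15), Node(7)))
best, log = traced_max_path_sum(tree)
for entry in log:
    print(entry)
# (9, 0, 0, 9, 9)
# (15, 0, 0, 15, 15)
# (7, 0, 0, 15, 7)
# (20, 15, 7, 42, 35)
# (-10, 9, 35, 42, 25)
```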

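Complexity

Time: O(n). Each node is visited exactly once, and apart from the two recursive calls the work per node is constant.

Space: O(h) for the recursion stack, where h is the height of the tree. That is O(log n) for a balanced tree but O(n) for a degenerate, linked-list-shaped one.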
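The recursion depth equals the tree height, so a degenerate tree with a few thousand nodes exceeds CPython's default recursion limit of 1000. One workaround is an explicit-stack post-order traversal that computes the same gain values. It is sketched below under the hypothetical name max_path_sum_iterative. The other common option is raising the limit with sys.setrecursionlimit, but very deep native recursion can still overflow the C stack.

```python
from typing import Optional


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_path_sum_iterative(root: Node) -> int:
    """Same recurrence as the recursive solution, driven by an explicit stack."""
    best = float("-inf")
    gains = {}  # node -> unclamped gain, kept until its parent consumes it
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if not children_done:
            # Schedule this node again, to be processed after both children.
            stack.append((node, True))
            stack.append((node.left, False))
            stack.append((node.right, False))
            continue
        # Missing children contribute 0; negative gains are clamped to 0.
        left = max(0, gains.pop(node.left, 0))
        right = max(0, gains.pop(node.right, 0))
        best = max(best, node.val + left + right)
        gains[node] = node.val + max(left, right)
    return best


# A 5000-node chain: the recursive version would exceed the default limit.
chain = Node(1)
tail = chain
for _ in range(4999):
    tail.left = Node(1)
    tail = tail.left
print(max_path_sum_iterative(chain))  # 5000
```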

Variations worth knowing