
Maximum Sum BST in Binary Tree

Hard
Amazon
Topics:
Trees, Recursion, Dynamic Programming

Let's explore a tree problem. Suppose you are given the root of a binary tree. Your mission is to find the maximum sum of all keys within any subtree that is also a Binary Search Tree (BST). An empty subtree counts as a BST with sum 0, so if every BST subtree has a negative sum, return 0.

To clarify, a BST is defined by these rules:

  1. The left subtree of a node contains only nodes with keys that are strictly less than the node's key.
  2. The right subtree of a node contains only nodes with keys that are strictly greater than the node's key.
  3. Both the left and right subtrees must themselves be binary search trees.
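Rules 1 and 2 constrain every node in a subtree, not only a node's direct children. The minimal sketch below contrasts a hypothetical `locally_ordered` check, which compares each node only with its immediate children, against a bounds-based `is_bst` check that passes each key an allowed (lo, hi) window inherited from all of its ancestors. Both helper names are illustrative, not part of the problem.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def locally_ordered(node: Optional[TreeNode]) -> bool:
    """Pitfall: only compares each node with its direct children."""
    if not node:
        return True
    if node.left and node.left.val >= node.val:
        return False
    if node.right and node.right.val <= node.val:
        return False
    return locally_ordered(node.left) and locally_ordered(node.right)


def is_bst(node: Optional[TreeNode], lo=float('-inf'), hi=float('inf')) -> bool:
    """Rules 1-3: every key must fit strictly inside the window set by its ancestors."""
    if not node:
        return True
    if not lo < node.val < hi:
        return False
    return is_bst(node.left, lo, node.val) and is_bst(node.right, node.val, hi)


# 6 sits in the left subtree of 5 (breaking rule 1), even though 3 < 6 locally.
tree = TreeNode(5, TreeNode(3, right=TreeNode(6)), TreeNode(8))
print(locally_ordered(tree), is_bst(tree))  # True False
```

The window-based check is the form both solutions below rely on, either explicitly (as bounds) or implicitly (as subtree minimum and maximum).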

Consider these examples to illustrate the problem:

Example 1:

Given the tree represented by the array [1,4,3,2,4,2,5,null,null,null,null,null,null,4,6], the correct output is 20. This is because the BST rooted at the node with the value 3 (keys 3, 2, 5, 4 and 6) has the largest sum: 3 + 2 + 5 + 4 + 6 = 20.

Example 2:

For the tree [4,3,null,1,2], the answer is 2. The largest BST sum comes from the single node with the value 2.

Example 3:

If the tree is [-4,-2,-5], the expected output is 0. All values are negative, so every non-empty BST subtree has a negative sum, and the best choice is the empty BST, whose sum is 0.
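The arrays in these examples appear to use the level-order (breadth-first) serialization common on LeetCode-style judges, where `null` marks a missing child. The sketch below assumes that convention; `build_tree` and `subtree_sum` are hypothetical helpers, not part of the problem statement. It rebuilds Example 1 and confirms that the subtree rooted at key 3 sums to 20.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list; None marks a missing child (assumed format)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value in the list is this node's left child.
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that (if present) is this node's right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def subtree_sum(node: Optional[TreeNode]) -> int:
    """Sum of all keys in the subtree rooted at node."""
    if node is None:
        return 0
    return node.val + subtree_sum(node.left) + subtree_sum(node.right)


# Example 1: the winning BST is rooted at root.right (key 3).
root = build_tree([1, 4, 3, 2, 4, 2, 5, None, None, None, None, None, None, 4, 6])
print(root.right.val, subtree_sum(root.right))  # 3 20
```

Python's `None` stands in for the `null` entries of the problem's arrays.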

Your code should be efficient and handle various test cases, including edge cases such as empty trees or trees with all negative values. How would you approach this problem?

Solution


Naive Approach: Brute Force

At the heart of this problem is checking if each subtree is a valid Binary Search Tree (BST) and then computing its sum. A straightforward but inefficient solution is to traverse every node in the tree, check if the subtree rooted at that node is a BST, and if so, compute its sum. We keep track of the maximum BST sum encountered during this process.

Algorithm

  1. isBST(node): A recursive function to determine if the subtree rooted at node is a BST.
  2. treeSum(node): A recursive function to calculate the sum of all nodes in the subtree rooted at node.
  3. Traverse the tree. For each node:
    • If isBST(node) returns true, calculate the subtree sum using treeSum(node).
    • Update the maximum BST sum if the current subtree sum is greater.

Code (Python)

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

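def isBST(node, min_val=float('-inf'), max_val=float('inf')):
    # An empty subtree is trivially a BST.
    if not node:
        return True
    
    # Every key must lie strictly inside the (min_val, max_val) window set by its ancestors.
    if node.val <= min_val or node.val >= max_val:
        return False

    # Left keys must stay below node.val, right keys above it.
    return (isBST(node.left, min_val, node.val) and
            isBST(node.right, node.val, max_val))

def treeSum(node):
    # Sum of every key in the subtree rooted at node.
    if not node:
        return 0
    return node.val + treeSum(node.left) + treeSum(node.right)

def maxSumBST(root):
    # Starting at 0 means the empty BST wins when every BST sum is negative.
    max_bst_sum = 0
    
    def traverse(node):
        nonlocal max_bst_sum
        if not node:
            return
        
        # Re-check and re-sum this subtree from scratch: the source of the quadratic cost.
        if isBST(node):
            max_bst_sum = max(max_bst_sum, treeSum(node))
            
        traverse(node.left)
        traverse(node.right)
    
    traverse(root)
    return max_bst_sum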

Time Complexity

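  • isBST(node) and treeSum(node) each visit every node in the subtree rooted at node, so a single call costs O(N) in the worst case.
  • We call these functions once for each node in the tree.
  • The total work is therefore proportional to the sum of all subtree sizes, which is at most O(N*H), where N is the number of nodes and H is the height of the tree.
  • For a skewed tree (H = N) this is O(N^2), the worst case; for a balanced tree it drops to O(N log N).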
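To make the quadratic bound concrete, the sketch below instruments copies of isBST and treeSum with a counter and runs them over a right-skewed chain 1 -> 2 -> ... -> n, which is itself one valid BST. The helpers `make_right_chain` and `count_brute_force_visits` are hypothetical; the outer traversal uses an explicit stack so that only the nodes touched by the two checks are counted.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def make_right_chain(n: int) -> Optional[TreeNode]:
    """Right-skewed chain 1 -> 2 -> ... -> n; the whole chain is a valid BST."""
    root = None
    for v in range(n, 0, -1):
        root = TreeNode(v, right=root)
    return root


def count_brute_force_visits(root: Optional[TreeNode]) -> int:
    """Count non-null nodes touched by isBST + treeSum across the brute-force pass."""
    visits = 0

    def is_bst(node, lo=float('-inf'), hi=float('inf')):
        nonlocal visits
        if not node:
            return True
        visits += 1
        if node.val <= lo or node.val >= hi:
            return False
        return is_bst(node.left, lo, node.val) and is_bst(node.right, node.val, hi)

    def tree_sum(node):
        nonlocal visits
        if not node:
            return 0
        visits += 1
        return node.val + tree_sum(node.left) + tree_sum(node.right)

    # Visit every node once with a stack; each visit triggers the from-scratch checks.
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        if is_bst(node):
            tree_sum(node)
        stack.extend(c for c in (node.left, node.right) if c)
    return visits


for n in (10, 100, 500):
    print(n, count_brute_force_visits(make_right_chain(n)))
# prints 10 110, 100 10100, 500 250500: exactly n * (n + 1) visits
```

Each node of the chain with k nodes below and including it costs 2k visits, and summing 2k over k = 1..n gives n(n + 1), which grows as O(N^2).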

Space Complexity

The space complexity is dominated by the recursion stack, which can be O(H) in the worst case, where H is the height of the tree. In the worst case (skewed tree), H = N, so the space complexity is O(N). In the average case (balanced tree), it's O(log N).

Optimal Approach: Bottom-Up Traversal with Information Passing

Instead of repeatedly checking for BSTs from scratch, we can use a more efficient bottom-up approach. For each node, we combine information from its children to determine whether the subtree rooted at that node is a BST, along with its sum, minimum value, and maximum value. Because each node is processed exactly once, this removes the repeated subtree scans of the brute force.

Algorithm

  1. Define a recursive helper function dfs(node) that returns a tuple:

    • is_bst: Boolean indicating if the subtree is a BST.
    • bst_sum: The sum of the BST, or 0 if it's not a BST.
    • min_val: The minimum value in the subtree (meaningful only when is_bst is true).
    • max_val: The maximum value in the subtree (meaningful only when is_bst is true).
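  2. Base case: If node is None, return (True, 0, float('inf'), float('-inf')). An empty subtree is a BST with sum 0, and giving it a minimum of +inf and a maximum of -inf means it can never violate its parent's comparisons.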

  3. Recursively process the left and right subtrees.

  4. Check if the current subtree is a BST:

    • The left subtree must be a BST (left_is_bst).
    • The right subtree must be a BST (right_is_bst).
    • The current node's value must be greater than the maximum value in the left subtree (node.val > left_max).
    • The current node's value must be less than the minimum value in the right subtree (node.val < right_min).
  5. If the current subtree is a BST, calculate its sum: bst_sum = node.val + left_sum + right_sum. Update the running maximum max_bst_sum (a nonlocal variable in the code) if bst_sum is greater.

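  6. Return the tuple for the current subtree. If it is a BST, return (True, bst_sum, min(node.val, left_min), max(node.val, right_max)). Otherwise, return (False, 0, float('-inf'), float('inf')); these inverted bounds guarantee that every ancestor also fails the BST check.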

Code (Python)

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

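def maxSumBST(root):
    max_bst_sum = 0  # best BST sum seen so far; 0 represents the empty BST

    def dfs(node):
        # Returns (is_bst, bst_sum, min_val, max_val) for the subtree rooted at node.
        nonlocal max_bst_sum

        # Empty subtree: a BST with sum 0 whose bounds never block the parent.
        if not node:
            return True, 0, float('inf'), float('-inf')

        # Post-order: gather both children's results first.
        left_is_bst, left_sum, left_min, left_max = dfs(node.left)
        right_is_bst, right_sum, right_min, right_max = dfs(node.right)

        if (left_is_bst and right_is_bst and
            node.val > left_max and node.val < right_min):
            
            current_sum = node.val + left_sum + right_sum
            max_bst_sum = max(max_bst_sum, current_sum)
            
            # left_min is +inf with no left child, so min() falls back to node.val (same for max).
            current_min = min(node.val, left_min)
            current_max = max(node.val, right_max)
            
            return True, current_sum, current_min, current_max
        else:
            # Inverted bounds guarantee that every ancestor also fails the BST check.
            return False, 0, float('-inf'), float('inf')

    dfs(root)
    return max_bst_sum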

Time Complexity

  • The dfs function visits each node exactly once.
  • Therefore, the overall time complexity is O(N), where N is the number of nodes in the tree.

Space Complexity

  • The space complexity is dominated by the recursion stack, which can be O(H) in the worst case, where H is the height of the tree. In the worst case (skewed tree), H = N, so the space complexity is O(N). In the average case (balanced tree), it's O(log N).
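On the recursion-stack point: CPython's default recursion limit is 1000 frames (see sys.getrecursionlimit()), so the recursive dfs can raise RecursionError on a sufficiently deep skewed tree. One possible workaround, sketched below, keeps the same per-node tuple but drives the post-order traversal with an explicit stack. The name `max_sum_bst_iterative` is hypothetical, and this is a swapped-in iterative technique rather than the solution above; auxiliary space stays proportional to the tree height, but it now lives in a Python list and dict instead of the call stack.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_sum_bst_iterative(root: Optional[TreeNode]) -> int:
    """Bottom-up (is_bst, sum, min, max) pass driven by an explicit stack instead of recursion."""
    EMPTY = (True, 0, float('inf'), float('-inf'))     # info for a missing child
    NOT_BST = (False, 0, float('-inf'), float('inf'))  # inverted bounds make every ancestor fail
    info = {}  # node -> info tuple; TreeNode hashes by identity
    best = 0
    stack = [(root, False)] if root else []
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Post-order: come back to this node after both children are finished.
            stack.append((node, True))
            for child in (node.right, node.left):
                if child:
                    stack.append((child, False))
            continue
        # Children were finished earlier, so their info is ready (and no longer needed).
        l_bst, l_sum, l_min, l_max = info.pop(node.left) if node.left else EMPTY
        r_bst, r_sum, r_min, r_max = info.pop(node.right) if node.right else EMPTY
        if l_bst and r_bst and l_max < node.val < r_min:
            total = node.val + l_sum + r_sum
            best = max(best, total)
            info[node] = (True, total, min(node.val, l_min), max(node.val, r_max))
        else:
            info[node] = NOT_BST
    return best


# A right-skewed chain 1 -> 2 -> ... -> 5000 is one big BST, far deeper than 1000 frames.
chain = None
for v in range(5000, 0, -1):
    chain = TreeNode(v, right=chain)
print(max_sum_bst_iterative(chain))  # 12502500
```

An alternative is raising the limit with sys.setrecursionlimit, but that does not enlarge the underlying C stack, so very deep trees can still crash the interpreter.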

Edge Cases and Considerations

  1. Empty Tree: The code handles empty trees gracefully, as the base case returns a valid BST with a sum of 0.
  2. Negative Values: The algorithm works correctly with negative values in the tree.
  3. All Negative Values: If every node has a negative value, every non-empty BST subtree has a negative sum, so max_bst_sum never rises above its initial value of 0 and the function returns 0 (the sum of an empty BST), as in Example 3.
  4. Duplicate Values: The BST definition uses strict inequalities, so a key equal to its ancestor's key anywhere in the left or right subtree breaks the BST property. If a variant of the problem allows duplicates (for example, equal keys placed in the left subtree), the comparisons must change accordingly, e.g. node.val >= left_max instead of node.val > left_max.
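For a variant that places equal keys in the left subtree, the sketch below shows that only one comparison in dfs needs to change. The flag `allow_equal_left` is a hypothetical parameter for illustration; the original problem always uses the strict version.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_sum_bst(root: Optional[TreeNode], allow_equal_left: bool = False) -> int:
    best = 0

    def dfs(node):
        nonlocal best
        if not node:
            return True, 0, float('inf'), float('-inf')
        l_bst, l_sum, l_min, l_max = dfs(node.left)
        r_bst, r_sum, r_min, r_max = dfs(node.right)
        # Only the left-side comparison changes: <= admits keys equal to node.val on the left.
        left_ok = l_max <= node.val if allow_equal_left else l_max < node.val
        if l_bst and r_bst and left_ok and node.val < r_min:
            total = node.val + l_sum + r_sum
            best = max(best, total)
            return True, total, min(node.val, l_min), max(node.val, r_max)
        # Inverted bounds still fail every ancestor, with or without the flag.
        return False, 0, float('-inf'), float('inf')

    dfs(root)
    return best


tree = TreeNode(2, TreeNode(2), TreeNode(3))
print(max_sum_bst(tree))                         # 3: the equal key on the left breaks the strict BST
print(max_sum_bst(tree, allow_equal_left=True))  # 7: the whole tree now qualifies
```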