Let's explore a tree problem. Suppose you are given the root of a binary tree. Your task is to find the maximum sum of all keys within any subtree that qualifies as a Binary Search Tree (BST). If no BST subtree has a positive sum, return 0.
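To clarify, a BST is defined by these rules:
The left subtree of a node contains only nodes with keys less than the node's key.
The right subtree of a node contains only nodes with keys greater than the node's key.
Both the left and right subtrees must themselves be BSTs.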
Consider these examples to illustrate the problem:
Example 1:
Given the tree represented by the array [1,4,3,2,4,2,5,null,null,null,null,null,null,4,6], the correct output is 20. This is because the BST rooted at the node with the value 3 (containing the keys 3, 2, 5, 4, and 6) yields the largest BST sum.
Example 2:
For the tree [4,3,null,1,2], the answer is 2. The largest BST sum comes from the single node with the value 2.
Example 3:
If the tree is [-4,-2,-5], the expected output is 0. In this case, all values are negative, so no BST subtree has a positive sum.
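The examples above describe trees as level-order arrays in which null marks a missing child. To experiment with them, a small builder can help. The following is a minimal sketch; the name build_tree is hypothetical and not part of the problem statement, and it assumes the usual level-order convention in which only existing nodes consume child slots.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a binary tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        parent = queue.popleft()
        # The next entry is the left child, if present.
        if values[i] is not None:
            parent.left = TreeNode(values[i])
            queue.append(parent.left)
        i += 1
        # The entry after that is the right child, if present.
        if i < len(values) and values[i] is not None:
            parent.right = TreeNode(values[i])
            queue.append(parent.right)
        i += 1
    return root


example_1 = build_tree([1, 4, 3, 2, 4, 2, 5, None, None, None, None, None, None, 4, 6])
example_2 = build_tree([4, 3, None, 1, 2])
print(example_1.right.val, example_1.right.right.left.val)  # 3 4
```

In Example 1, the right child of the root is the node with value 3, whose subtree (3, 2, 5, 4, 6) is the BST that sums to 20.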
Your code should be efficient and handle various test cases, including edge cases such as empty trees or trees with all negative values. How would you approach this problem?
At the heart of this problem is checking if each subtree is a valid Binary Search Tree (BST) and then computing its sum. A straightforward but inefficient solution is to traverse every node in the tree, check if the subtree rooted at that node is a BST, and if so, compute its sum. We keep track of the maximum BST sum encountered during this process.
isBST(node): A recursive function to determine if the subtree rooted at node is a BST.
treeSum(node): A recursive function to calculate the sum of all nodes in the subtree rooted at node.
Traverse every node; whenever isBST(node) returns true, calculate the subtree sum using treeSum(node) and update the maximum.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def isBST(node, min_val=float('-inf'), max_val=float('inf')):
    # An empty subtree is a valid BST.
    if not node:
        return True
    # Every key must lie strictly between the bounds inherited from its ancestors.
    if node.val <= min_val or node.val >= max_val:
        return False
    return (isBST(node.left, min_val, node.val) and
            isBST(node.right, node.val, max_val))

def treeSum(node):
    if not node:
        return 0
    return node.val + treeSum(node.left) + treeSum(node.right)

def maxSumBST(root):
    max_bst_sum = 0  # Starting at 0 covers the case where every BST sum is negative.

    def traverse(node):
        nonlocal max_bst_sum
        if not node:
            return
        # Re-check the BST property and re-sum the subtree from scratch at every node.
        if isBST(node):
            max_bst_sum = max(max_bst_sum, treeSum(node))
        traverse(node.left)
        traverse(node.right)

    traverse(root)
    return max_bst_sum
isBST takes O(N) time in the worst case (skewed tree) for each node.
treeSum also takes O(N) time in the worst case for each node.
Because both may run for each of the N nodes, the overall time complexity is O(N^2) in the worst case.
The space complexity is dominated by the recursion stack, which is O(H), where H is the height of the tree. In the worst case (skewed tree), H = N, so the space complexity is O(N). For a balanced tree, it is O(log N).
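To make the quadratic behavior concrete, the sketch below counts how many non-empty nodes the repeated BST checks touch on a right-skewed tree. The helpers count_is_bst_visits and right_chain are hypothetical names introduced only for this illustration; the BST check itself mirrors the brute-force logic described above.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_is_bst_visits(root):
    """Count the non-empty nodes touched by all brute-force BST checks combined."""
    visits = 0

    def is_bst(node, lo=float('-inf'), hi=float('inf')):
        nonlocal visits
        if not node:
            return True
        visits += 1
        if node.val <= lo or node.val >= hi:
            return False
        return is_bst(node.left, lo, node.val) and is_bst(node.right, node.val, hi)

    def traverse(node):
        if not node:
            return
        is_bst(node)  # a fresh check starting at every node
        traverse(node.left)
        traverse(node.right)

    traverse(root)
    return visits


def right_chain(n):
    """Build the right-skewed tree 1 -> 2 -> ... -> n, which is a valid BST."""
    root = TreeNode(1)
    node = root
    for v in range(2, n + 1):
        node.right = TreeNode(v)
        node = node.right
    return root


for n in (10, 20, 40):
    print(n, count_is_bst_visits(right_chain(n)))
# 10 55
# 20 210
# 40 820
```

Doubling n roughly quadruples the count, since the check started at the k-th node of the chain visits the n - k + 1 nodes below it, for a total of n(n + 1)/2.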
Instead of repeatedly checking for BSTs from scratch, we can use a more efficient bottom-up approach. For each node, we collect information from its children to determine if the subtree rooted at that node is a BST, as well as its sum, minimum value, and maximum value. This approach reduces redundant calculations and improves performance.
Define a recursive helper function dfs(node) that returns a tuple:
is_bst: Boolean indicating if the subtree is a BST.
bst_sum: The sum of the subtree if it is a BST, or 0 if it's not.
min_val: The minimum value in the subtree.
max_val: The maximum value in the subtree.
Base case: If node is None, return (True, 0, float('inf'), float('-inf')). An empty subtree counts as a BST, and these sentinel bounds let the comparisons in the BST check pass when a child is missing.
Recursively process the left and right subtrees.
Check if the current subtree is a BST. All four conditions must hold:
The left subtree is a BST (left_is_bst).
The right subtree is a BST (right_is_bst).
The node's value is greater than the maximum of the left subtree (node.val > left_max).
The node's value is less than the minimum of the right subtree (node.val < right_min).
If the current subtree is a BST, calculate its sum: bst_sum = node.val + left_sum + right_sum.
Update the global max_bst_sum if bst_sum is greater.
Return the appropriate tuple based on whether the current subtree is a BST.
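The merge step described in the list above can be isolated as a small pure function. This is only a sketch: the names Info, EMPTY, NOT_BST, and combine are hypothetical and chosen for illustration, with a NamedTuple standing in for the plain tuple.

```python
from typing import NamedTuple


class Info(NamedTuple):
    is_bst: bool
    bst_sum: int
    min_val: float
    max_val: float


# Result for a missing child: a valid BST whose bounds never block the parent.
EMPTY = Info(True, 0, float('inf'), float('-inf'))
# Result for a non-BST subtree: bounds that make every ancestor fail the check too.
NOT_BST = Info(False, 0, float('-inf'), float('inf'))


def combine(val, left: Info, right: Info) -> Info:
    """Merge the children's results with the current node's value."""
    if left.is_bst and right.is_bst and left.max_val < val < right.min_val:
        return Info(
            True,
            val + left.bst_sum + right.bst_sum,
            min(val, left.min_val),
            max(val, right.max_val),
        )
    return NOT_BST


leaf_1 = combine(1, EMPTY, EMPTY)
leaf_2 = combine(2, EMPTY, EMPTY)
print(leaf_2)                             # Info(is_bst=True, bst_sum=2, min_val=2, max_val=2)
print(combine(3, leaf_1, leaf_2).is_bst)  # False
```

The last call corresponds to the node with value 3 in Example 2: its right child holds 2, which is not greater than 3, so the subtree is rejected.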
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def maxSumBST(root):
    max_bst_sum = 0

    def dfs(node):
        nonlocal max_bst_sum
        # Empty subtree: a valid BST with sum 0 and neutral bounds.
        if not node:
            return True, 0, float('inf'), float('-inf')
        left_is_bst, left_sum, left_min, left_max = dfs(node.left)
        right_is_bst, right_sum, right_min, right_max = dfs(node.right)
        if (left_is_bst and right_is_bst and
                node.val > left_max and node.val < right_min):
            current_sum = node.val + left_sum + right_sum
            max_bst_sum = max(max_bst_sum, current_sum)
            current_min = min(node.val, left_min)
            current_max = max(node.val, right_max)
            return True, current_sum, current_min, current_max
        else:
            # Not a BST: these bounds ensure no ancestor passes the check either.
            return False, 0, float('-inf'), float('inf')

    dfs(root)
    return max_bst_sum
The dfs function visits each node exactly once, so the time complexity is O(N).
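As a sanity check, the bottom-up idea can be run against the three examples from the problem statement. The sketch below restates the algorithm compactly under the hypothetical name max_sum_bst and builds each example tree by hand, so it does not depend on any other snippet.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_sum_bst(root):
    best = 0

    def dfs(node):
        """Return (is_bst, sum, min, max) for the subtree rooted at node."""
        nonlocal best
        if not node:
            return True, 0, float('inf'), float('-inf')
        l_ok, l_sum, l_min, l_max = dfs(node.left)
        r_ok, r_sum, r_min, r_max = dfs(node.right)
        if l_ok and r_ok and l_max < node.val < r_min:
            total = node.val + l_sum + r_sum
            best = max(best, total)
            return True, total, min(node.val, l_min), max(node.val, r_max)
        return False, 0, float('-inf'), float('inf')

    dfs(root)
    return best


N = TreeNode  # short alias to keep the hand-built trees readable
example_1 = N(1, N(4, N(2), N(4)), N(3, N(2), N(5, N(4), N(6))))
example_2 = N(4, N(3, N(1), N(2)))
example_3 = N(-4, N(-2), N(-5))

print(max_sum_bst(example_1))  # 20
print(max_sum_bst(example_2))  # 2
print(max_sum_bst(example_3))  # 0
print(max_sum_bst(None))       # 0
```

The last two calls cover the edge cases mentioned earlier: a tree with only negative values and an empty tree both yield 0.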