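Given a binary tree `root`, return the maximum sum of all keys of any sub-tree which is also a Binary Search Tree (BST). Assume a BST is defined as follows:

- The left subtree of a node contains only nodes with keys less than the node's key.
- The right subtree of a node contains only nodes with keys greater than the node's key.
- Both the left and right subtrees must also be binary search trees.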
For example, given the binary tree `[1,4,3,2,4,2,5,null,null,null,null,null,null,4,6]`, the expected output is `20`. The maximum sum is obtained in the subtree rooted at the node with key `3`.
As another example, given the binary tree `[4,3,null,1,2]`, the expected output is `2`. The maximum sum is obtained by the single node with key `2`, which on its own forms a BST.
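The examples use the LeetCode-style level-order serialization, in which `null` marks a missing child. To try them locally, a small helper can rebuild a tree from such a list. `build_tree` below is a hypothetical helper (not part of the problem), written under the assumption that Python's `None` stands in for `null`.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two list entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


# Example 2 from the problem statement.
root = build_tree([4, 3, None, 1, 2])
```

In the second example this yields a root `4` whose only child is `3`, and `3` has children `1` and `2`; since `2` is a right child smaller than `3`, the subtree at `3` is not a BST.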
What is the most efficient algorithm to solve this problem, and what is its time and space complexity?
```python
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def maxSumBST(self, root: Optional[TreeNode]) -> int:
        max_sum = 0  # an empty BST has sum 0, so the answer is never negative

        def traverse(node):
            """Return (min key, max key, key sum, is_bst) for the subtree at node."""
            nonlocal max_sum
            # Base case: an empty subtree is a BST whose bounds never block its parent
            if not node:
                return float('inf'), float('-inf'), 0, True  # min, max, sum, is_bst
            # Post-order: process the left and right subtrees first
            left_min, left_max, left_sum, left_bst = traverse(node.left)
            right_min, right_max, right_sum, right_bst = traverse(node.right)
            # The current subtree is a BST iff both children are BSTs and
            # every left key < node.val < every right key
            if left_bst and right_bst and left_max < node.val < right_min:
                current_sum = left_sum + right_sum + node.val
                max_sum = max(max_sum, current_sum)
                current_min = min(left_min, node.val)
                current_max = max(right_max, node.val)
                return current_min, current_max, current_sum, True
            else:
                # Not a BST: these sentinel bounds also make every ancestor's check fail
                return float('-inf'), float('inf'), 0, False

        traverse(root)
        return max_sum
```
The problem requires finding the maximum sum of all keys in any subtree that is a Binary Search Tree (BST). The solution uses a single post-order (bottom-up) traversal of the binary tree, so each node is examined only after both of its subtrees.
`TreeNode` class: Represents a node in the binary tree.

`maxSumBST(root)` function:

- Initializes `max_sum` to 0. This variable stores the maximum BST sum found.
- Calls the `traverse` function to recursively explore the tree.
- Returns `max_sum`.

`traverse(node)` function:

- Base case: if `node` is `None`, it returns `(inf, -inf, 0, True)`. Using `inf` as the minimum and `-inf` as the maximum guarantees that an empty child never fails the parent's comparison, `0` is the sum of an empty tree, and `True` signifies that an empty tree is a BST.
- It calls `traverse` on the left and right subtrees to get their min/max keys, sums, and BST status.
- It checks whether the current subtree is a BST: `left_bst and right_bst and left_max < node.val < right_min`.
- If so, it calculates `current_sum`, updates `max_sum`, and returns the updated min/max values and `current_sum`, along with `True` indicating it is a BST.
- Otherwise, it returns `(-inf, inf, 0, False)`, indicating that this subtree cannot form a valid BST. These sentinel bounds also make the parent's comparison fail, so no ancestor can be a BST either.

For the input `[1,4,3,2,4,2,5,null,null,null,null,null,null,4,6]`:

The `traverse` function recursively explores the tree, identifies BST subtrees (such as the single leaves, the subtree rooted at 5 with sum 15, and the subtree rooted at 3 with sum 20), computes their sums, and updates `max_sum` accordingly. Finally, `max_sum`, which is the largest BST sum (20), is returned.
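To see which subtrees the traversal accepts on the first example, the sketch below reuses the same `traverse` logic but records `(root key, sum)` for every BST subtree it finds. `bst_subtree_sums` is a hypothetical name introduced for illustration, and the tree is built by hand rather than parsed from the level-order list.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def bst_subtree_sums(root: Optional[TreeNode]) -> List[Tuple[int, int]]:
    """Hypothetical helper: (root key, key sum) of every non-empty BST subtree, in post-order."""
    found: List[Tuple[int, int]] = []

    def traverse(node):
        if not node:
            return float('inf'), float('-inf'), 0, True
        left_min, left_max, left_sum, left_bst = traverse(node.left)
        right_min, right_max, right_sum, right_bst = traverse(node.right)
        if left_bst and right_bst and left_max < node.val < right_min:
            current_sum = left_sum + right_sum + node.val
            found.append((node.val, current_sum))  # record instead of only keeping the max
            return min(left_min, node.val), max(right_max, node.val), current_sum, True
        return float('-inf'), float('inf'), 0, False

    traverse(root)
    return found


# Example 1: [1,4,3,2,4,2,5,null,null,null,null,null,null,4,6]
root = TreeNode(1,
                TreeNode(4, TreeNode(2), TreeNode(4)),
                TreeNode(3, TreeNode(2), TreeNode(5, TreeNode(4), TreeNode(6))))
for key, total in bst_subtree_sums(root):
    print(f"BST rooted at {key}: sum {total}")
```

In post-order it reports the leaves 2 and 4 under the left child, then the leaves 2, 4 and 6 on the right, then the subtree rooted at 5 (sum 15), and finally the subtree rooted at 3 (sum 20). The subtree rooted at the left child 4 is absent because its right child also has key 4, which the strict comparison rejects, and the root 1 is absent because its left subtree is not a BST.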
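Time complexity: O(n), where n is the number of nodes, because the `traverse` function visits each node in the tree exactly once and does constant work per visit.

Space complexity: O(h), where h is the height of the tree, for the recursion stack of the `traverse` function. This is O(log n) for a balanced tree and O(n) in the worst case (a skewed tree).

Edge cases:

- If no BST subtree has a positive sum (for example, when every key is negative), `max_sum` will remain 0, which is the correct result because an empty tree counts as a BST with sum 0.
- The check `left_max < node.val < right_min` is strict, so any subtree containing a duplicate key is rejected as a BST. This matches the usual strict BST definition; if duplicates should be allowed on one side, the comparison must be relaxed accordingly.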
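The O(h) space bound comes from the call stack, and in CPython that stack is capped by the interpreter's recursion limit (1000 by default, see `sys.getrecursionlimit()`), so a deeply skewed tree can raise `RecursionError` even though the algorithm itself is linear. One workaround, sketched here under the assumption that the same strict BST definition applies, is to drive the identical bottom-up check with an explicit stack. `max_sum_bst_iterative` is a hypothetical name, not part of the original answer.

```python
from typing import Dict, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# (min key, max key, key sum, is_bst) for one subtree
Info = Tuple[float, float, int, bool]
EMPTY: Info = (float('inf'), float('-inf'), 0, True)


def max_sum_bst_iterative(root: Optional[TreeNode]) -> int:
    """Same bottom-up check as traverse(), driven by an explicit stack."""
    max_sum = 0
    results: Dict[int, Info] = {}  # id(node) -> Info, removed once the parent uses it
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if not children_done:
            # Revisit this node after both children have been processed (post-order).
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))
            continue
        left_min, left_max, left_sum, left_bst = (
            results.pop(id(node.left)) if node.left else EMPTY)
        right_min, right_max, right_sum, right_bst = (
            results.pop(id(node.right)) if node.right else EMPTY)
        if left_bst and right_bst and left_max < node.val < right_min:
            current_sum = left_sum + right_sum + node.val
            max_sum = max(max_sum, current_sum)
            results[id(node)] = (min(left_min, node.val), max(right_max, node.val),
                                 current_sum, True)
        else:
            results[id(node)] = (float('-inf'), float('inf'), 0, False)
    return max_sum


# A right-skewed chain 1 -> 2 -> ... -> 5000 is a valid BST far deeper than
# the default recursion limit.
chain = None
for v in range(5000, 0, -1):
    chain = TreeNode(v, None, chain)
print(max_sum_bst_iterative(chain))
```

Each child's result is removed from `results` as soon as its parent consumes it, so both the stack and the dictionary hold O(h) entries, matching the recursive version's space bound while avoiding the recursion limit. Raising the limit with `sys.setrecursionlimit` is the simpler alternative, but setting it too high can crash the interpreter.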