Largest BST Subtree

Medium
Meta
Topics: Trees, Recursion

You are given a binary tree. A Binary Search Tree (BST) is a node-based binary tree data structure which has the following properties:

  • The left subtree of a node contains only nodes with keys less than the node's key.
  • The right subtree of a node contains only nodes with keys greater than the node's key.
  • Both the left and right subtrees must also be binary search trees.

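Your task is to find the largest subtree of the given binary tree that is a BST, where a subtree consists of a node together with all of its descendants and "largest" means containing the most nodes. Return the number of nodes in the largest BST subtree.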

For example:

Consider the following binary tree:

      10
     /  \
    5    15
   / \   \
  1   8   7

The largest BST subtree is rooted at node 5, and it has the following structure:

    5
   / \
  1   8

It has 3 nodes. Therefore, the function should return 3.

As another example, consider:

    4
   / \
  2   7
 / \
1   3

The largest BST subtree is the entire tree, with 5 nodes. Therefore, the function should return 5.
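
To experiment with the two examples, the trees can be built by hand and inspected. The sketch below is only an illustration: the helper names inorder and is_strictly_increasing are assumptions introduced here, and the check relies on the standard fact that a binary tree satisfies the strict BST definition above exactly when its in-order traversal is strictly increasing.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder(node):
    """Return the values of the subtree in left-root-right order."""
    if node is None:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)


def is_strictly_increasing(values):
    """True when every value is smaller than the one that follows it."""
    return all(a < b for a, b in zip(values, values[1:]))


# First example: node 15 has a right child 7, which breaks the BST rule.
first = TreeNode(10,
                 TreeNode(5, TreeNode(1), TreeNode(8)),
                 TreeNode(15, None, TreeNode(7)))

# Second example: the whole tree is a BST.
second = TreeNode(4,
                  TreeNode(2, TreeNode(1), TreeNode(3)),
                  TreeNode(7))

print(inorder(first))                               # [1, 5, 8, 10, 15, 7]
print(is_strictly_increasing(inorder(first)))       # False
print(is_strictly_increasing(inorder(first.left)))  # True (subtree rooted at 5)
print(is_strictly_increasing(inorder(second)))      # True
```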

Could you provide an efficient algorithm to solve this problem? What is the time and space complexity of your approach?

Solution


Naive Solution

The most straightforward approach is to traverse the binary tree and, for each node, check if the subtree rooted at that node is a BST. If it is, calculate the size of the subtree. Keep track of the largest BST subtree found so far.

Algorithm

  1. Start at the root. An empty tree has a largest BST subtree of size 0.
  2. Check whether the subtree rooted at the current node is a BST.
  3. If it is, return its size; no subtree below it can be larger.
  4. Otherwise, recurse into the left and right children and return the larger of the two results.

Code

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def is_bst(root, min_val=float('-inf'), max_val=float('inf')):
    if not root:
        return True
    if root.val <= min_val or root.val >= max_val:
        return False
    return (is_bst(root.left, min_val, root.val) and
            is_bst(root.right, root.val, max_val))

def subtree_size(root):
    if not root:
        return 0
    return 1 + subtree_size(root.left) + subtree_size(root.right)

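def largest_bst_subtree_naive(root):
    if not root:
        return 0

    # If the whole subtree is a BST, nothing below it can be larger.
    if is_bst(root):
        return subtree_size(root)

    # Otherwise the answer lies entirely within one of the children.
    return max(largest_bst_subtree_naive(root.left),
               largest_bst_subtree_naive(root.right))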

Time Complexity

O(N^2) in the worst case, where N is the number of nodes in the tree. A single is_bst call can take O(N) time, and it may be called once for every node, for example on a skewed tree whose only BST violation sits at the deepest node.
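
The quadratic behaviour can be observed by counting how many nodes the BST checks touch. The following is a rough sketch, not part of the solution itself: naive_with_visit_count re-implements the naive approach with a counter, and worst_case_chain is a hypothetical helper that builds a right-skewed chain 1, 2, ..., n-1 whose deepest node holds 0, so every check walks to the bottom before failing.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def naive_with_visit_count(root):
    """Run the naive approach; return (answer, nodes touched by BST checks)."""
    visits = 0

    def is_bst(node, lo, hi):
        nonlocal visits
        if node is None:
            return True
        visits += 1
        if node.val <= lo or node.val >= hi:
            return False
        return is_bst(node.left, lo, node.val) and is_bst(node.right, node.val, hi)

    def size(node):
        if node is None:
            return 0
        return 1 + size(node.left) + size(node.right)

    def solve(node):
        if node is None:
            return 0
        if is_bst(node, float('-inf'), float('inf')):
            return size(node)
        return max(solve(node.left), solve(node.right))

    return solve(root), visits


def worst_case_chain(n):
    """Right-skewed chain of n >= 2 nodes: 1, 2, ..., n-1, then 0 at the bottom."""
    root = TreeNode(1)
    cur = root
    for v in range(2, n):
        cur.right = TreeNode(v)
        cur = cur.right
    cur.right = TreeNode(0)  # the single violation, placed as deep as possible
    return root


for n in (100, 200, 400):
    answer, visits = naive_with_visit_count(worst_case_chain(n))
    print(n, answer, visits)
# 100 1 5050
# 200 1 20100
# 400 1 80200
```

On this input the check started at the k-th node from the top touches n - k + 1 nodes, so the total is n(n + 1)/2: doubling n roughly quadruples the work.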

Space Complexity

O(H) due to the recursion stack, where H is the height of the tree. In the worst case, H can be N (skewed tree).

Optimal Solution

A more efficient solution involves traversing the tree in a post-order manner and, for each node, returning information about whether the subtree rooted at that node is a BST, along with its size, minimum value, and maximum value. This avoids redundant calculations.

Algorithm

  1. Traverse the tree in a post-order fashion.
  2. For each node, check if its left and right subtrees are BSTs.
  3. If both subtrees are BSTs and the node's value is within the valid range (greater than the maximum value in the left subtree and smaller than the minimum value in the right subtree), then the subtree rooted at the node is also a BST.
  4. Update the maximum size found so far.
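  5. Return the size, minimum value, maximum value, and an is-BST flag to the parent. If the subtree is not a BST, return the flag as False so that no ancestor is treated as a BST either.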

Code

class Solution:
    def largestBSTSubtree(self, root: TreeNode) -> int:
        max_size = 0

        def postorder(node):
            nonlocal max_size
            if not node:
                # Empty subtree: a BST of size 0. The sentinels (min=+inf, max=-inf)
                # make the range check below pass for any node value.
                return 0, float('inf'), float('-inf'), True  # size, min, max, is_bst

            left_size, left_min, left_max, is_left_bst = postorder(node.left)
            right_size, right_min, right_max, is_right_bst = postorder(node.right)

            if (is_left_bst and is_right_bst and
                    node.val > left_max and node.val < right_min):
                size = left_size + right_size + 1
                max_size = max(max_size, size)
                return size, min(left_min, node.val), max(right_max, node.val), True
            else:
                # Not a BST: size, min and max are never read once the flag is False.
                return 0, 0, 0, False

        postorder(root)
        return max_size
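
For readers who find the positional 4-tuple hard to follow, here is a sketch of the same post-order idea with named fields, run on the two examples from the problem statement. The names SubtreeInfo, EMPTY, NOT_BST and largest_bst_subtree are assumptions made for this illustration; the logic is intended to mirror the code above rather than replace it.

```python
from typing import NamedTuple, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class SubtreeInfo(NamedTuple):
    size: int      # number of nodes (meaningful only when is_bst is True)
    lo: float      # smallest value in the subtree
    hi: float      # largest value in the subtree
    is_bst: bool


# An empty subtree is a BST whose bounds never reject a parent value.
EMPTY = SubtreeInfo(0, float('inf'), float('-inf'), True)
# For a non-BST subtree only the flag matters.
NOT_BST = SubtreeInfo(0, 0, 0, False)


def largest_bst_subtree(root: Optional[TreeNode]) -> int:
    best = 0

    def visit(node: Optional[TreeNode]) -> SubtreeInfo:
        nonlocal best
        if node is None:
            return EMPTY
        left = visit(node.left)
        right = visit(node.right)
        if left.is_bst and right.is_bst and left.hi < node.val < right.lo:
            size = left.size + right.size + 1
            best = max(best, size)
            return SubtreeInfo(size, min(left.lo, node.val),
                               max(right.hi, node.val), True)
        return NOT_BST

    visit(root)
    return best


first = TreeNode(10,
                 TreeNode(5, TreeNode(1), TreeNode(8)),
                 TreeNode(15, None, TreeNode(7)))
second = TreeNode(4,
                  TreeNode(2, TreeNode(1), TreeNode(3)),
                  TreeNode(7))

print(largest_bst_subtree(first))   # 3
print(largest_bst_subtree(second))  # 5
```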

Time Complexity

O(N), where N is the number of nodes in the tree. Each node is visited once.

Space Complexity

O(H), where H is the height of the tree, due to the recursion stack. In the worst case, H can be N (skewed tree).

Edge Cases

  • Empty tree: Should return 0.
  • Single node tree: Should return 1.
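  • Skewed tree: The algorithm still works correctly, but the recursion depth equals N, so the space complexity is O(N). In Python, a very deep tree can also exceed the default recursion limit.
  • Tree with duplicate values: Under the definition above, keys in the left subtree must be strictly smaller and keys in the right subtree strictly larger, so a subtree with an equal key on either side is not a BST. Both solutions enforce this explicitly with strict comparisons: is_bst rejects values with <= and >=, and the post-order check requires > and <.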
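
The edge cases above can be exercised with an iterative variant that replaces the call stack with an explicit one, which sidesteps the recursion limit on skewed trees. This is a swapped-in technique rather than the solution shown earlier, and the names largest_bst_subtree_iterative and right_chain are hypothetical helpers for this sketch. It assumes TreeNode does not define __eq__, so nodes can be used as dictionary keys.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def largest_bst_subtree_iterative(root):
    """Post-order traversal with an explicit stack; same result tuple as before."""
    if root is None:
        return 0
    empty = (0, float('inf'), float('-inf'), True)  # size, min, max, is_bst
    info = {}  # finished child -> result tuple, removed once the parent reads it
    best = 0
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after both children have been processed.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue
        l_size, l_min, l_max, l_ok = info.pop(node.left) if node.left is not None else empty
        r_size, r_min, r_max, r_ok = info.pop(node.right) if node.right is not None else empty
        if l_ok and r_ok and l_max < node.val < r_min:
            size = l_size + r_size + 1
            best = max(best, size)
            info[node] = (size, min(l_min, node.val), max(r_max, node.val), True)
        else:
            info[node] = (0, 0, 0, False)
    return best


def right_chain(n):
    """Right-skewed BST holding 1..n, one node per level."""
    root = TreeNode(1)
    cur = root
    for v in range(2, n + 1):
        cur.right = TreeNode(v)
        cur = cur.right
    return root


print(largest_bst_subtree_iterative(None))                       # 0  (empty tree)
print(largest_bst_subtree_iterative(TreeNode(42)))               # 1  (single node)
print(largest_bst_subtree_iterative(TreeNode(2, TreeNode(2))))   # 1  (duplicate key)
print(largest_bst_subtree_iterative(right_chain(5000)))          # 5000 (skewed tree)
```

The chain of 5000 nodes is deeper than Python's default recursion limit of 1000, which is why the explicit stack is used here.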