Split BST

Medium
Coupang
Topics:
Trees, Recursion

Given the root of a binary search tree (BST) and an integer target, split the tree into two subtrees:

  • One subtree contains every node whose value is less than or equal to the target value,
  • The other subtree contains every node whose value is greater than the target value.

Return the roots of the two subtrees after the split.

Example 1:

Input: root = [4,2,6,1,3,5,7], target = 2
Output: [[2,1],[4,3,6,null,null,5,7]]

Example 2:

Input: root = [1,null,3,0,2,4,null], target = 4
Output: [[1,null,3,0,2],[4]]
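
The examples appear to use the level-order list notation common on interview sites, where null marks a missing child. The problem statement does not say this explicitly, so treat it as an assumption. Under that assumption, a minimal sketch for converting between such lists and linked nodes could look like the following. The names from_level_order and to_level_order are hypothetical helpers, not part of the problem.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def from_level_order(values):
    """Build a tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_level_order(root):
    """Serialize a tree to a level-order list, trimming trailing None entries."""
    if root is None:
        return []
    out = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out
```

With these helpers, the second output tree of Example 1, [4,3,6,null,null,5,7], reads as: root 4, left child 3 with no children, and right child 6 with children 5 and 7.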

Constraints:

  • The number of nodes in the tree is in the range [1, 500].
  • 0 <= Node.val <= 1000
  • All the values in the tree are unique.
  • target >= 0

Solution


Clarifying Questions

When you are asked this question in a real interview, it will often be ambiguous (especially at FAANG). Make sure to ask these questions in that case:

  1. What data type is the target value used to split the BST? Can I assume it's an integer?
  2. What should I return if the input BST is empty (null)?
  3. Is the target guaranteed to be present in the BST, or do I need to handle the case where it is smaller than the smallest element or larger than the largest element in the BST?
  4. Should the returned trees (the one with nodes <= target and the one with nodes > target) be valid BSTs after the split?
  5. Can I modify the original BST, or should I create entirely new nodes for the two resulting BSTs?

Brute Force Solution

Approach

The brute force approach to splitting a Binary Search Tree (BST) ignores the ordering the tree already provides. It examines every node, sorts each value into one of two groups by comparing it with the target, and then rebuilds a fresh BST from each group.

Here's how the algorithm would work step-by-step:

  1. Look at each node in the tree, one by one, visiting a parent before its children.
  2. If the node's value is less than or equal to the target, add it to the first group.
  3. Otherwise, add it to the second group.
  4. Build a new tree from the first group by inserting its values one at a time into an initially empty BST.
  5. Build a second tree from the second group in the same way.
  6. Check that the two resulting structures are valid Binary Search Trees, adhering to the correct order (left is less than, right is greater than).
  7. Return the roots of the two new trees.

Code Implementation

def split_bst_brute_force(root, value):
    less_equal_values = []
    greater_values = []

    def is_bst(root_node, min_value=float('-inf'), max_value=float('inf')):
        if not root_node:
            return True
        if not (min_value < root_node.val < max_value):
            return False
        return (is_bst(root_node.left, min_value, root_node.val)
                and is_bst(root_node.right, root_node.val, max_value))

    # Need to examine every single node in the tree
    def traverse(root_node):
        if not root_node:
            return

        # Preorder: record a parent before its children, so that
        # re-inserting the values in this order places each parent first
        if root_node.val <= value:
            less_equal_values.append(root_node.val)
        else:
            greater_values.append(root_node.val)

        traverse(root_node.left)
        traverse(root_node.right)

    # Rebuild a BST by inserting the values one at a time
    def create_tree(node_values):
        if not node_values:
            return None

        root_node = TreeNode(node_values[0])
        for node_value in node_values[1:]:
            current = root_node
            while True:
                if node_value < current.val:
                    if current.left is None:
                        current.left = TreeNode(node_value)
                        break
                    else:
                        current = current.left
                else:
                    if current.right is None:
                        current.right = TreeNode(node_value)
                        break
                    else:
                        current = current.right
        return root_node

    traverse(root)
    less_equal_tree = create_tree(less_equal_values)
    greater_tree = create_tree(greater_values)

    # Only return the split if both results are valid BSTs
    if is_bst(less_equal_tree) and is_bst(greater_tree):
        return less_equal_tree, greater_tree

    return None, None

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

Big(O) Analysis

Time Complexity
O(N²): The algorithm examines each of the N nodes in the BST once, which takes O(N) time. Rebuilding is the expensive part: placing a value into a partially built tree walks one root-to-leaf path, and that path can be up to N nodes long when the tree is skewed. Doing this for up to N values costs O(N²) in the worst case. The validation step, which checks whether the subtrees are valid BSTs, visits each node once and adds only O(N). The total time complexity is therefore O(N²).
Space Complexity
O(N): The brute force approach stores every value in an auxiliary list and allocates a new node for each one, which takes O(N) space. The recursive traversal also adds a stack frame per level, so in the worst case (a skewed tree) the recursion depth is proportional to the number of nodes, N. The auxiliary space is therefore O(N).

Optimal Solution

Approach

The problem asks us to divide a Binary Search Tree (BST) into two BSTs, one containing nodes with values less than or equal to a given value, and the other containing the remaining nodes. The clever idea is to recursively traverse the tree, making use of the BST property at each node to efficiently split it.

Here's how the algorithm would work step-by-step:

  1. Start at the root of the BST.
  2. If the current node's value is less than or equal to the given value, then keep this node in the 'less than or equal to' tree. Importantly, its left subtree also belongs to this tree. Recursively split the right subtree of the current node; the left part of that split (nodes less than or equal to the target value) becomes the right child of the current node, and the right part of that split goes to the 'greater than' tree.
  3. If the current node's value is greater than the given value, keep this node in the 'greater than' tree. Symmetrically, its right subtree also belongs to this tree. Recursively split the left subtree of the current node; the right part of that split (nodes greater than the target value) becomes the left child of the current node, and the left part of that split goes to the 'less than or equal to' tree.
  4. Because of the BST property, each call recurses into only one child, so the algorithm follows a single path from the root towards a leaf.
  5. When you reach a null node, simply return two empty trees. This is the base case.
  6. As the recursion unwinds, each node re-attaches the matching half of its child's split, so both results remain valid BSTs.

Code Implementation

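def split_bst(root, target_value):
    if not root:
        # Base case: an empty tree splits into two empty trees
        return None, None

    if root.val <= target_value:
        # Keep root (and its left subtree) in the 'less than or equal' tree;
        # only the right subtree can hold values greater than the target
        less_than_or_equal_to_tree, greater_than_tree = split_bst(root.right, target_value)
        # Attach the 'less than or equal' part of that split to the right of root
        root.right = less_than_or_equal_to_tree

        return root, greater_than_tree

    else:
        # Keep root (and its right subtree) in the 'greater than' tree;
        # only the left subtree can hold values less than or equal to the target
        less_than_or_equal_to_tree, greater_than_tree = split_bst(root.left, target_value)
        # Attach the 'greater than' part of that split to the left of root
        root.left = greater_than_tree

        return less_than_or_equal_to_tree, root

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def construct_bst():
    root = TreeNode(4)
    root.left = TreeNode(2)
    root.right = TreeNode(6)
    root.left.left = TreeNode(1)
    root.left.right = TreeNode(3)
    root.right.left = TreeNode(5)
    root.right.right = TreeNode(7)
    return root

def inorder_values(root_node):
    # An inorder traversal of a BST yields its values in sorted order
    if not root_node:
        return []
    return inorder_values(root_node.left) + [root_node.val] + inorder_values(root_node.right)

if __name__ == '__main__':
    root = construct_bst()
    target_value = 3
    less_than_or_equal_to_tree, greater_than_tree = split_bst(root, target_value)

    # Print each tree's values in sorted order so the split can be inspected
    print("Less than or equal:", inorder_values(less_than_or_equal_to_tree))  # [1, 2, 3]
    print("Greater than:", inorder_values(greater_than_tree))  # [4, 5, 6, 7]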
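
A split can be checked with assertions even without drawing the trees. The sketch below is one possible way to do it, not part of the original solution: it repeats the recursive split so the block runs on its own, then verifies that both results are valid BSTs, that every value landed on the correct side, and that no value was lost. The names check_split, inorder, is_bst and build_example are hypothetical helpers introduced for illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_bst(root, target_value):
    # Same recursive split as in the solution above.
    if not root:
        return None, None
    if root.val <= target_value:
        low, high = split_bst(root.right, target_value)
        root.right = low
        return root, high
    low, high = split_bst(root.left, target_value)
    root.left = high
    return low, root


def inorder(node):
    """Return the values of a tree in left-root-right order."""
    if not node:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)


def is_bst(node, low=float('-inf'), high=float('inf')):
    """Check the BST ordering with strict bounds (values are unique)."""
    if not node:
        return True
    if not (low < node.val < high):
        return False
    return is_bst(node.left, low, node.val) and is_bst(node.right, node.val, high)


def check_split(root, target_value):
    """Split the tree and assert the properties the problem asks for."""
    expected = inorder(root)  # capture before the split mutates the tree
    low_tree, high_tree = split_bst(root, target_value)
    low_values, high_values = inorder(low_tree), inorder(high_tree)
    assert is_bst(low_tree) and is_bst(high_tree)
    assert all(v <= target_value for v in low_values)
    assert all(v > target_value for v in high_values)
    assert low_values + high_values == expected  # nothing lost or duplicated
    return low_values, high_values


def build_example():
    """The tree [4,2,6,1,3,5,7] from Example 1."""
    return TreeNode(4,
                    TreeNode(2, TreeNode(1), TreeNode(3)),
                    TreeNode(6, TreeNode(5), TreeNode(7)))
```

Calling check_split(build_example(), 3) returns ([1, 2, 3], [4, 5, 6, 7]). The inorder list is captured before the split because this version of split_bst modifies the tree in place.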

Big(O) Analysis

Time Complexity
O(H): At each node the algorithm compares the node's value with the target and then recurses into exactly one child, so it follows a single path from the root towards a leaf. The number of recursive calls is therefore bounded by the height H of the tree. In a balanced BST, H is about log N, which gives O(log N), similar to binary search. In the worst-case scenario (a skewed BST), H can be N, and the time complexity degrades to O(N).
Space Complexity
O(H): The space complexity is determined by the recursion depth. In the worst case, the recursion depth can be equal to the height of the BST, denoted as H, as the split function is recursively called for each node along a path from the root to a leaf. Therefore, the auxiliary space used by the call stack is proportional to H. In a balanced BST, H would be log(N), while in a skewed BST, H could be N, where N is the number of nodes in the BST.
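
If the interviewer asks that the original tree be left untouched (one of the clarifying questions above), the same recursion can allocate new nodes only along the split path and share every untouched subtree. The following is a sketch under that assumption, not the original solution: split_bst_copy, inorder and build_example are hypothetical names, and sharing subtrees is only safe if nobody mutates either tree afterwards.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_bst_copy(root, target_value):
    """Split without modifying the input; allocates O(H) new nodes."""
    if root is None:
        return None, None
    if root.val <= target_value:
        low, high = split_bst_copy(root.right, target_value)
        # New node reuses the original left subtree unchanged.
        return TreeNode(root.val, root.left, low), high
    low, high = split_bst_copy(root.left, target_value)
    # New node reuses the original right subtree unchanged.
    return low, TreeNode(root.val, high, root.right)


def inorder(node):
    """Return the values of a tree in left-root-right order."""
    if not node:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)


def build_example():
    """The tree [4,2,6,1,3,5,7] from Example 1."""
    return TreeNode(4,
                    TreeNode(2, TreeNode(1), TreeNode(3)),
                    TreeNode(6, TreeNode(5), TreeNode(7)))
```

The time cost stays O(H). The extra space is O(H) for the call stack plus O(H) for the copied path, so the asymptotic bound is unchanged.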

Edge Cases

Case: How to Handle

  • Null root or empty tree: Return [None, None] immediately as there's nothing to split.
  • Target is smaller than the smallest element in the BST: Return [None, root] as the 'less than or equal' tree will be empty.
  • Target is larger than the largest element in the BST: Return [root, None] as the 'greater than' tree will be empty.
  • Tree with only one node: If the node's value is less than or equal to the target, return [root, None], otherwise return [None, root].
  • Tree with highly skewed structure (e.g., linked list): The recursive calls will still traverse the tree, but performance might degrade to O(n) where n is the number of nodes.
  • Target exists in the tree: The matching node goes to the 'less than or equal' tree. The constraints guarantee unique values, so multiple nodes equal to the target cannot occur.
  • Values close to MAX_INT or MIN_INT: Comparisons do not overflow, and the constraints bound values to [0, 1000], so no special handling is needed. In a fixed-width language, keep the target and the node values in the same integer type.
  • Deeply unbalanced tree exceeding recursion depth limit: Consider converting to an iterative approach to avoid stack overflow with extremely large or unbalanced trees.
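
The last case suggests an iterative approach. One way to sketch it, assuming in-place modification is allowed, is to walk the same root-to-leaf path with a loop and keep a "tail" pointer into each result tree. The function name split_bst_iterative and the helpers build_right_chain and chain_length are hypothetical and exist only for this illustration. Under the stated constraint of at most 500 nodes the recursive version is already safe with Python's default recursion limit, so this matters mainly for larger inputs.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_bst_iterative(root, target_value):
    """Split a BST in place using O(1) extra space and no recursion."""
    low_dummy = TreeNode()    # low_dummy.right will be the '<=' tree
    high_dummy = TreeNode()   # high_dummy.left will be the '>' tree
    low_tail, high_tail = low_dummy, high_dummy
    node = root
    while node:
        if node.val <= target_value:
            # node and its left subtree belong to the '<=' tree;
            # hang it in the open right slot of the previous '<=' node.
            low_tail.right = node
            low_tail = node
            node = node.right
        else:
            # node and its right subtree belong to the '>' tree;
            # hang it in the open left slot of the previous '>' node.
            high_tail.left = node
            high_tail = node
            node = node.left
    # Cut any leftover link that still points into the other tree.
    low_tail.right = None
    high_tail.left = None
    return low_dummy.right, high_dummy.left


def build_right_chain(n):
    """Hypothetical helper: a fully right-skewed BST holding 1..n (n >= 1)."""
    root = tail = TreeNode(1)
    for v in range(2, n + 1):
        tail.right = TreeNode(v)
        tail = tail.right
    return root


def chain_length(node):
    """Count nodes by following right pointers only."""
    count = 0
    while node:
        count += 1
        node = node.right
    return count
```

For example, split_bst_iterative(build_right_chain(5000), 2500) returns two right-skewed chains of 2500 nodes each, an input deep enough that a recursive split would exceed Python's default recursion limit.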