
Find Leaves of Binary Tree

Medium
Amazon
Topics: Trees, Recursion

You are given the root of a binary tree. Your task is to repeatedly remove all of the tree's current leaves until the tree is empty, and return the removed values as a list of lists, where each inner list holds the leaves removed in one iteration.

For example:

  1. If the input is an empty tree (root is None), return an empty list [].
  2. If the input is a tree with only one node (root), return a list containing a list with the root's value, i.e., [[root.val]].
  3. Consider the following binary tree:
        1
       / \
      2   3
     / \
    4   5

The expected output would be [[4, 5, 3], [2], [1]].

Explanation:

  • First iteration: Leaves are [4, 5, 3]. Remove them.
  • Second iteration: The tree becomes:
       1
      / 
     2   

Leaves are [2]. Remove it.

  • Third iteration: The tree becomes:
    1

Leaves are [1]. Remove it. The tree is now empty.

Write a function to implement this algorithm and analyze its time and space complexity. Can you describe any edge cases that might occur?
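To try the solutions on the example above, the tree first has to be built. The sketch below assumes a LeetCode-style level-order list in which None marks a missing child; build_tree and current_leaves are hypothetical helpers, not part of the original solution. current_leaves lists the leaves that the first iteration would remove.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list in which None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right child slots.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def current_leaves(node: Optional[TreeNode]) -> List[int]:
    """Values of the tree's current leaves, listed left to right."""
    if node is None:
        return []
    if node.left is None and node.right is None:
        return [node.val]
    return current_leaves(node.left) + current_leaves(node.right)


example = build_tree([1, 2, 3, 4, 5])
print(current_leaves(example))  # [4, 5, 3]
```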

Solution



Problem Description

Given a binary tree, find all the leaves and then remove them. Repeat this process until the tree is empty. Return a list of lists of integers, where each inner list represents the values of the leaves removed at each step.

Naive Solution

A brute-force approach involves repeatedly traversing the tree to identify and remove the leaves. In each iteration, we identify all leaves, store their values, remove them from the tree, and repeat the process until the tree is empty.

  1. Find Leaves: Traverse the tree and identify all leaf nodes.
  2. Remove Leaves: Remove the identified leaf nodes from the tree.
  3. Repeat: Repeat steps 1 and 2 until the tree is empty.

This approach is straightforward but inefficient, because each iteration traverses the entire remaining tree again.

Code (Python):

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

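def findLeaves(root):
    result = []
    # Each pass strips the current leaves. Nodes are detached from the
    # caller's tree, so the input is modified in place.
    while root:
        leaves = []
        root = removeLeaves(root, leaves)
        result.append(leaves)
    return result

def removeLeaves(root, leaves):
    # Returns the subtree with its current leaves removed,
    # appending the removed values to leaves (left to right).
    if not root:
        return None
    if not root.left and not root.right:
        leaves.append(root.val)  # current leaf: record it and drop it
        return None
    root.left = removeLeaves(root.left, leaves)
    root.right = removeLeaves(root.right, leaves)
    return root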

Time Complexity: O(N^2) in the worst case, where N is the number of nodes in the tree. Each iteration traverses the remaining tree (O(N)), and a skewed tree, in which every node has at most one child, needs N iterations because only one leaf is removed per round.
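To see where the quadratic bound comes from, the sketch below re-runs the same round-by-round removal while counting the non-null nodes it visits (make_chain and count_naive_visits are hypothetical helpers). On a chain of n nodes each round strips a single node, so the rounds visit n, n - 1, ..., 1 nodes, n(n + 1)/2 in total.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def make_chain(n: int) -> Optional[TreeNode]:
    """Left-skewed chain 1 -> 2 -> ... -> n (node n is the only leaf)."""
    root = None
    for v in range(n, 0, -1):
        root = TreeNode(v, left=root)
    return root


def count_naive_visits(root: Optional[TreeNode]) -> Tuple[int, int]:
    """Run the round-by-round removal; return (non-null nodes visited, rounds)."""
    visits = 0

    def remove_leaves(node, leaves):
        nonlocal visits
        if node is None:
            return None
        visits += 1
        if node.left is None and node.right is None:
            leaves.append(node.val)
            return None
        node.left = remove_leaves(node.left, leaves)
        node.right = remove_leaves(node.right, leaves)
        return node

    rounds = 0
    while root is not None:
        root = remove_leaves(root, [])
        rounds += 1
    return visits, rounds


print(count_naive_visits(make_chain(100)))  # (5050, 100), i.e. 100 * 101 / 2 visits
```

For comparison, a perfect tree with 7 nodes needs only 3 rounds and 7 + 3 + 1 = 11 visits.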

Space Complexity: O(N) in the worst case. The recursion stack grows to the height of the tree, which is N for a skewed tree, and the output stores all N values.
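The removeLeaves approach detaches nodes from the caller's tree, so after findLeaves returns, the input has been reduced to its root node alone. If the tree must stay intact, one variant (a sketch; find_leaves_preserving and the removed set are hypothetical names) records removed nodes in a set and treats a node as a leaf once all of its children are gone. It still makes one pass per round, so it keeps the O(N^2) worst case.

```python
from typing import List, Optional, Set


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_preserving(root: Optional[TreeNode]) -> List[List[int]]:
    removed: Set[TreeNode] = set()  # nodes hash by identity (no __eq__ defined)

    def is_gone(node: Optional[TreeNode]) -> bool:
        return node is None or node in removed

    def collect(node: Optional[TreeNode], round_leaves: List[TreeNode]) -> None:
        # A present node whose children are all gone is a leaf in this round.
        if is_gone(node):
            return
        if is_gone(node.left) and is_gone(node.right):
            round_leaves.append(node)
            return
        collect(node.left, round_leaves)
        collect(node.right, round_leaves)

    result = []
    while not is_gone(root):
        round_leaves: List[TreeNode] = []
        collect(root, round_leaves)
        # Mark removals only after the pass, so nodes that just became
        # leaves are not removed until the next round.
        removed.update(round_leaves)
        result.append([n.val for n in round_leaves])
    return result


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves_preserving(root))  # [[4, 5, 3], [2], [1]]
print(root.left.left.val)            # 4: the input tree is unchanged
```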

Optimal Solution

A more efficient approach is to use a bottom-up traversal and assign a height to each node. The height of a node is the number of edges on the longest path from that node down to a leaf, so leaves have a height of 0. A node of height h becomes a leaf only after every node below it has been removed, which happens in iteration h + 1. Nodes with the same height are therefore removed in the same iteration, and we can group the nodes by their height.

  1. Calculate Height: Perform a post-order traversal of the tree. During the traversal, calculate the height of each node recursively. The height of a leaf node is 0, and the height of an internal node is max(height(left), height(right)) + 1.
  2. Group by Height: Append each node's value to the list at index height in the result, so the list at index h holds the values removed in iteration h + 1.
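Applied to the example tree, the heights work out as follows. This sketch (node_heights is a hypothetical helper, and it assumes node values are distinct so they can serve as dictionary keys) records each node's height in post-order. Grouping by height gives 0 -> [4, 5, 3], 1 -> [2], 2 -> [1], which matches the expected output.

```python
from typing import Dict, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def node_heights(root: Optional[TreeNode]) -> Dict[int, int]:
    """Map each node's value to its height (assumes distinct values)."""
    heights: Dict[int, int] = {}

    def height(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1  # empty subtree, so a leaf gets max(-1, -1) + 1 = 0
        h = max(height(node.left), height(node.right)) + 1
        heights[node.val] = h  # recorded in post-order
        return h

    height(root)
    return heights


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(node_heights(root))  # {4: 0, 5: 0, 2: 1, 3: 0, 1: 2}
```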

Code (Python):

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

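def findLeaves(root):
    result = []  # result[h] collects the values of nodes with height h
    def getHeight(node):
        if not node:
            return -1  # empty subtree, so a leaf gets max(-1, -1) + 1 = 0
        # Post-order: both children's heights are computed before this node's.
        height = max(getHeight(node.left), getHeight(node.right)) + 1
        if len(result) <= height:
            # A child at height - 1 was already recorded, so one new list suffices.
            result.append([])
        result[height].append(node.val)
        return height
    getHeight(root)
    return result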

Time Complexity: O(N), where N is the number of nodes in the tree. We visit each node exactly once.

Space Complexity: O(N), where N is the number of nodes in the tree. The recursion stack uses O(H) space, where H is the height of the tree (up to N when the tree is skewed), and the result stores all N values.
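One way to gain confidence that grouping by height reproduces the round-by-round removal, including the left-to-right order inside each group, is to compare the two approaches on random trees. The sketch below re-implements both in snake_case (find_leaves_by_rounds, find_leaves_by_height, random_tree, and approaches_agree are hypothetical names) and runs the height version first, because the round-based version dismantles its input.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_by_rounds(root: Optional[TreeNode]) -> List[List[int]]:
    # Naive approach: strip the current leaves once per round (mutates the tree).
    def remove_leaves(node, leaves):
        if node is None:
            return None
        if node.left is None and node.right is None:
            leaves.append(node.val)
            return None
        node.left = remove_leaves(node.left, leaves)
        node.right = remove_leaves(node.right, leaves)
        return node

    result = []
    while root is not None:
        leaves: List[int] = []
        root = remove_leaves(root, leaves)
        result.append(leaves)
    return result


def find_leaves_by_height(root: Optional[TreeNode]) -> List[List[int]]:
    # Optimal approach: bucket each node by its height in one post-order pass.
    result: List[List[int]] = []

    def height(node):
        if node is None:
            return -1
        h = max(height(node.left), height(node.right)) + 1
        if len(result) == h:
            result.append([])
        result[h].append(node.val)
        return h

    height(root)
    return result


def random_tree(n: int, rng: random.Random) -> Optional[TreeNode]:
    # Attach nodes 0..n-1 one at a time to a randomly chosen free child slot.
    if n == 0:
        return None
    root = TreeNode(0)
    nodes = [root]
    for v in range(1, n):
        while True:
            parent = rng.choice(nodes)
            side = rng.choice(("left", "right"))
            if getattr(parent, side) is None:
                child = TreeNode(v)
                setattr(parent, side, child)
                nodes.append(child)
                break
    return root


def approaches_agree(trials: int = 500, max_nodes: int = 40, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        root = random_tree(rng.randint(0, max_nodes), rng)
        expected = find_leaves_by_height(root)  # non-destructive, so run it first
        if find_leaves_by_rounds(root) != expected:
            return False
    return True


print(approaches_agree())  # True
```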

Edge Cases

  • Empty Tree: If the input tree is empty (root is None), return an empty list.
  • Single Node Tree: If the tree contains only the root node, return a list containing a list with the root's value.
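  • Skewed Tree: In a left- or right-skewed tree, every node has at most one child, so each iteration removes exactly one node and the result is N single-element lists. The recursion depth also reaches N, which can exceed Python's default recursion limit for very deep trees.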
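Both recursive solutions recurse once per level, so on a skewed tree the recursion depth equals N. In CPython the default recursion limit (see sys.getrecursionlimit(), usually 1000) means a chain of roughly a thousand nodes raises RecursionError. One workaround, sketched here with a hypothetical find_leaves_iterative, computes the same heights with an explicit post-order stack. It is an alternative to the recursive version above, not part of the original solution; it still runs in O(N) time and handles the empty and single-node cases as well.

```python
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_iterative(root: Optional[TreeNode]) -> List[List[int]]:
    result: List[List[int]] = []
    height: Dict[TreeNode, int] = {}  # nodes hash by identity (no __eq__ defined)
    stack = [(root, False)] if root is not None else []
    while stack:
        node, children_done = stack.pop()
        if children_done:
            # Missing children default to -1, so a leaf gets height 0.
            h = max(height.get(node.left, -1), height.get(node.right, -1)) + 1
            height[node] = h
            if len(result) == h:
                result.append([])
            result[h].append(node.val)
        else:
            # Re-push the node to finish after both subtrees; push right before
            # left so the left subtree is handled first (left, right, node).
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
    return result


# A left-skewed chain far deeper than the default recursion limit.
chain = None
for v in range(5000, 0, -1):
    chain = TreeNode(v, left=chain)
groups = find_leaves_iterative(chain)
print(len(groups), groups[0], groups[-1])  # 5000 [5000] [1]
```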

Summary

The optimal solution uses node height to group the nodes that are removed together in each iteration, computing every group in a single post-order pass. This improves the time complexity from O(N^2) for the naive solution to O(N).