You are given the root of a binary tree. Each node in the tree has a value of either 0 or 1. Your task is to prune the tree by removing all subtrees that do not contain the value 1. A subtree of a node 'node' consists of 'node' itself and all of its descendants.
Write a function that takes the root of the binary tree as input and returns the root of the pruned tree. The pruned tree should only contain nodes and subtrees that have at least one node with the value 1.
Example 1:
Consider the following binary tree:
```
1
 \
  0
 / \
0   1
```
After pruning, the tree should look like this:
```
1
 \
  0
   \
    1
```
Only the subtree rooted at the left '0' on the third level is removed, because it contains no '1'. The '0' on the second level is kept because its right child is a '1'.
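To make Example 1 concrete, here is a minimal sketch that builds the input tree by hand and prunes it with the single-pass (post-order) approach described later in this section. `to_tuple` is a hypothetical display helper, not part of the solution. It renders a tree as nested `(val, left, right)` tuples.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def prune_tree(root):
    """Single-pass post-order pruning, mirroring the optimized approach."""
    def prune_and_check(node):
        if node is None:
            return False
        left_has = prune_and_check(node.left)
        right_has = prune_and_check(node.right)
        if not left_has:
            node.left = None  # left subtree has no 1: detach it
        if not right_has:
            node.right = None  # right subtree has no 1: detach it
        return node.val == 1 or left_has or right_has

    return root if prune_and_check(root) else None


def to_tuple(node):
    """Hypothetical display helper: nested (val, left, right) tuples, None for no child."""
    if node is None:
        return None
    return (node.val, to_tuple(node.left), to_tuple(node.right))


# Example 1 input: root 1, whose right child 0 has children 0 (left) and 1 (right).
example1 = TreeNode(1, right=TreeNode(0, TreeNode(0), TreeNode(1)))
print(to_tuple(prune_tree(example1)))
# (1, None, (0, None, (1, None, None)))
```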
Example 2:
Consider the following binary tree:
```
     1
   /   \
  0     1
 / \   / \
0   0 0   1
```
After pruning, the tree should look like this:
```
1
 \
  1
   \
    1
```
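The root's entire left subtree (a '0' with two '0' children) is removed, as is the left child '0' of the root's right child. Every node that remains has the value 1.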
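Trees like the one in Example 2 are often written as level-order lists in which `None` marks a missing child. This is a common convention, assumed here, since the problem above uses diagrams only. The hypothetical helpers `build_tree` and `to_level_order` below convert between that encoding and `TreeNode` objects, which makes it easy to check Example 2. The pruning function follows the single-pass approach described later.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values):
    """Hypothetical helper: build a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_level_order(root):
    """Inverse of build_tree: level-order list with trailing None entries trimmed."""
    out = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


def prune_tree(root):
    """Single-pass post-order pruning: detach child subtrees that contain no 1."""
    def prune_and_check(node):
        if node is None:
            return False
        left_has = prune_and_check(node.left)
        right_has = prune_and_check(node.right)
        if not left_has:
            node.left = None
        if not right_has:
            node.right = None
        return node.val == 1 or left_has or right_has

    return root if prune_and_check(root) else None


print(to_level_order(prune_tree(build_tree([1, 0, 1, 0, 0, 0, 1]))))
# [1, None, 1, None, 1]
```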
Constraints:
Each node's value is either 0 or 1.
Explain your approach, analyze its time and space complexity, and handle any relevant edge cases.
Given the root of a binary tree, return the same tree where every subtree (of the given tree) not containing a 1 has been removed. A subtree of a node `node` is `node` plus every node that is a descendant of `node`.
A naive approach would involve traversing the tree and, for each node, determining if the subtree rooted at that node contains a 1. If it doesn't, remove the subtree. This can be done recursively.
Algorithm:
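1. Define a helper function `containsOne(node)` that returns `true` if the subtree rooted at `node` contains a 1, and `false` otherwise.
2. `containsOne(node)`:
   - If `node` is `null`, return `false`.
   - Otherwise, call `containsOne` recursively on `node.left` and `node.right`.
   - Return `true` if `node.val` is 1 or either of the recursive calls returned `true`.
3. `pruneTree(root)`:
   - If `root` is `null`, return `null`.
   - If `containsOne(root)` is `false`, return `null`.
   - Otherwise, set `root.left = pruneTree(root.left)` and `root.right = pruneTree(root.right)`, then return `root`.

Code (Python):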
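```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def contains_one(node):
    # Scans the subtree rooted at node, stopping early once a 1 is found.
    if not node:
        return False
    return node.val == 1 or contains_one(node.left) or contains_one(node.right)


def pruneTree(root):
    if not root:
        return None
    # Drop the whole subtree if it contains no 1 ...
    if not contains_one(root):
        return None
    # ... otherwise keep root and prune each child independently.
    root.left = pruneTree(root.left)
    root.right = pruneTree(root.right)
    return root
```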
Time Complexity: O(N^2) in the worst case, where N is the number of nodes in the tree. For each node, `containsOne` may traverse that node's entire subtree. A skewed tree of 0s whose only 1 is the deepest node triggers this case, costing N + (N-1) + ... + 1 = N(N+1)/2 node visits.
Space Complexity: O(H), where H is the height of the tree, due to the recursive call stack.
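The quadratic bound can be checked empirically. The sketch below instruments both approaches with a visit counter and runs them on a worst-case input: a right-leaning chain of n nodes whose only 1 is the deepest node. The counters and the `make_chain` helper are hypothetical additions for illustration, not part of the solution.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def make_chain(n):
    """Hypothetical helper: right-leaning chain of n nodes, all 0 except the deepest (1)."""
    root = None
    for i in range(n):
        root = TreeNode(1 if i == 0 else 0, right=root)
    return root


def naive_visits(root):
    """Runs the naive containsOne-based pruning and counts real nodes inspected."""
    visits = 0

    def contains_one(node):
        nonlocal visits
        if node is None:
            return False
        visits += 1
        return node.val == 1 or contains_one(node.left) or contains_one(node.right)

    def prune(node):
        if node is None or not contains_one(node):
            return None
        node.left = prune(node.left)
        node.right = prune(node.right)
        return node

    prune(root)
    return visits


def one_pass_visits(root):
    """Runs the single-pass pruning and counts real nodes visited."""
    visits = 0

    def prune_and_check(node):
        nonlocal visits
        if node is None:
            return False
        visits += 1
        left = prune_and_check(node.left)
        right = prune_and_check(node.right)
        if not left:
            node.left = None
        if not right:
            node.right = None
        return node.val == 1 or left or right

    prune_and_check(root)
    return visits


for n in (10, 100):
    print(n, naive_visits(make_chain(n)), one_pass_visits(make_chain(n)))
# 10 55 10
# 100 5050 100
```

The naive count grows as n(n+1)/2, while the single-pass count equals n.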
An optimized approach combines the pruning and the check for the presence of 1 into a single recursive function. This eliminates redundant traversals.
Algorithm:
1. Define a helper function `pruneAndCheck(node)` that returns `true` if the subtree rooted at `node` (after pruning) contains a 1, and `false` otherwise. This function also prunes the tree.
2. `pruneAndCheck(node)`:
   - If `node` is `null`, return `false`.
   - Call `pruneAndCheck` recursively on `node.left` and `node.right`.
   - Prune `node.left` and `node.right` based on the results of the recursive calls: set a child to `null` if its call returned `false`.
   - Return `true` if `node.val` is 1 or either of the recursive calls returned `true`.
3. `pruneTree(root)`:
   - Call `pruneAndCheck(root)`.
   - Return `root` if `pruneAndCheck(root)` returned `true`; otherwise, return `null`.

Code (Python):
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def pruneTree(root):
    def prune_and_check(node):
        # Post-order: settle both children before deciding about this node.
        if not node:
            return False
        left_contains_one = prune_and_check(node.left)
        right_contains_one = prune_and_check(node.right)
        # Detach any child subtree that has no 1 in it.
        if not left_contains_one:
            node.left = None
        if not right_contains_one:
            node.right = None
        return node.val == 1 or left_contains_one or right_contains_one

    # Keep the root only if the pruned tree still contains a 1.
    if prune_and_check(root):
        return root
    else:
        return None
```
Time Complexity: O(N), where N is the number of nodes in the tree, as each node is visited exactly once.
Space Complexity: O(H), where H is the height of the tree, due to the recursive call stack.
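The O(H) stack has a practical consequence in Python. For a path-shaped tree, H equals N, and CPython's default recursion limit (1000, reported by `sys.getrecursionlimit()`) can cause a `RecursionError` on deep inputs. One possible workaround, not part of the solution above, is to run the same post-order logic with an explicit stack. The function name `prune_tree_iterative` is hypothetical. Here is a sketch:

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def prune_tree_iterative(root):
    """Iterative post-order pruning using an explicit stack instead of recursion."""
    if root is None:
        return None
    has_one = {}  # node -> whether its pruned subtree contains a 1 (identity-hashed)
    stack = [(root, False)]  # (node, children_already_processed)
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after both children have been processed.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue
        left_has = node.left is not None and has_one[node.left]
        right_has = node.right is not None and has_one[node.right]
        if not left_has:
            node.left = None
        if not right_has:
            node.right = None
        has_one[node] = node.val == 1 or left_has or right_has
    return root if has_one[root] else None


# A 5,000-node path of 0s ending in a 1: deeper than the default recursion limit.
root = None
for i in range(5000):
    root = TreeNode(1 if i == 0 else 0, left=root)
pruned = prune_tree_iterative(root)
print(pruned is root)  # True: every node lies on the path to the single 1
```

The trade-off is readability. The recursive version maps directly onto the algorithm, while this version only matters when very deep trees are expected.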
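Edge Cases:
- Empty tree (`root` is `null`): the function should return `null`.
- Tree with no 1s (every node is 0): the whole tree is pruned, so the function also returns `null`.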
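The following quick checks exercise several edge cases against the single-pass `pruneTree`. The function is reproduced here so the snippet runs on its own, and the specific trees are illustrative choices. The last check covers a case that is easy to get wrong: a root with value 0 survives when one of its descendants holds a 1.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def pruneTree(root):
    def prune_and_check(node):
        if not node:
            return False
        left_contains_one = prune_and_check(node.left)
        right_contains_one = prune_and_check(node.right)
        if not left_contains_one:
            node.left = None
        if not right_contains_one:
            node.right = None
        return node.val == 1 or left_contains_one or right_contains_one

    return root if prune_and_check(root) else None


# Empty tree: nothing to prune.
assert pruneTree(None) is None
# All 0s: the entire tree is removed.
assert pruneTree(TreeNode(0, TreeNode(0), TreeNode(0, TreeNode(0)))) is None
# A single 1 is returned unchanged.
single = TreeNode(1)
assert pruneTree(single) is single
# A 0 root is kept because its right child is a 1; the 0 leaf on the left is removed.
root = TreeNode(0, TreeNode(0), TreeNode(1))
kept = pruneTree(root)
assert kept is root and kept.left is None and kept.right.val == 1
print("all edge cases pass")
```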