Given a binary tree root and an integer target, delete all the leaf nodes with value target.
Note that once you delete a leaf node with value target, if its parent node becomes a leaf node and has the value target, it should also be deleted (you need to continue doing that until you cannot).
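For example, consider the following binary tree:

```
      1
     / \
    2   3
   / \ / \
  2  4 2  2
```

If target = 2, the algorithm proceeds as follows. The three leaves with value 2 (the left child of the inner 2, and both children of 3) are deleted:

```
      1
     / \
    2   3
     \
      4
```

Node 3 is now a leaf, but its value is not 2, so it stays. The remaining 2 still has the child 4, so it is not a leaf and also stays. No further deletions are possible, so this is the final tree.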
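Consider the following binary tree:

```
    1
   / \
  3   3
 / \
3   2
```

If target = 3, the algorithm should return:

```
    1
   /
  3
   \
    2
```

Both leaves with value 3 are deleted. The remaining 3 still has the child 2, so it is not a leaf and is kept.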
Write a function that takes the root of a binary tree and a target value as input and returns the modified binary tree after removing all leaf nodes with the specified target value, propagating the deletion upwards as needed.
Constraints:
The number of nodes in the tree is in the range [1, 3000].
1 <= Node.val, target <= 1000
A brute-force approach would involve traversing the tree and identifying leaf nodes with the target value. If such a node is found, it's removed. After the removal, the tree is re-evaluated to see if the removal created new leaf nodes with the target value, repeating until no more such nodes exist.
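```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def remove_leaf_nodes(root, target):
    def is_target_leaf(node):
        # A node qualifies if it exists, has no children, and holds the target value.
        return (node is not None and node.left is None
                and node.right is None and node.val == target)

    def remove_once(node):
        # One pass: detach children that are target leaves at the moment they are
        # inspected. Parents that only become leaves later in this pass are left
        # for the next pass. Returns True if anything was removed.
        if node is None:
            return False
        removed = False
        if is_target_leaf(node.left):
            node.left = None
            removed = True
        if is_target_leaf(node.right):
            node.right = None
            removed = True
        removed = remove_once(node.left) or removed
        removed = remove_once(node.right) or removed
        return removed

    # A dummy parent lets the root be removed like any other leaf.
    dummy = TreeNode(0, left=root)
    # Repeat full passes until a pass removes nothing.
    while remove_once(dummy):
        pass
    return dummy.left
```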
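Time complexity: O(N*K), where N is the number of nodes in the tree and K is the number of passes made before a pass removes nothing. Each pass visits every remaining node. In the worst case K is O(N): in a chain of nodes that all equal target, each pass can remove only the bottom node, so the total is O(N^2).
Space complexity: O(H), where H is the height of the tree, due to the recursion stack.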
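To see why K can grow linearly, the following sketch (an illustration, not part of the original solution) counts the passes a repeated-pass brute force needs on a degenerate chain whose nodes all equal target. The helpers count_passes and build_chain are hypothetical names; count_passes re-implements one pass so that it removes only nodes that are leaves when inspected, as in the brute force described above.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_passes(root, target):
    """Hypothetical helper: run the brute-force loop, return (new_root, passes)."""

    def is_target_leaf(node):
        return (node is not None and node.left is None
                and node.right is None and node.val == target)

    def remove_once(node):
        # Only children that are target leaves when inspected are removed.
        if node is None:
            return False
        removed = False
        if is_target_leaf(node.left):
            node.left = None
            removed = True
        if is_target_leaf(node.right):
            node.right = None
            removed = True
        removed = remove_once(node.left) or removed
        removed = remove_once(node.right) or removed
        return removed

    dummy = TreeNode(0, left=root)  # sentinel parent for the root
    passes = 1  # the first pass always runs
    while remove_once(dummy):
        passes += 1  # the previous pass removed something, so run another
    return dummy.left, passes


def build_chain(n, val):
    """Hypothetical helper: a left-leaning chain of n nodes, all with value val."""
    root = None
    for _ in range(n):
        root = TreeNode(val, left=root)
    return root


for n in (1, 5, 20):
    new_root, passes = count_passes(build_chain(n, 7), 7)
    print(n, passes, new_root)
# Expected output:
# 1 2 None
# 5 6 None
# 20 21 None
```

Each pass strips only the current bottom node, so a chain of n nodes needs n removing passes plus one final pass that finds nothing, which is where the O(N^2) worst case comes from.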
A more efficient solution uses a post-order traversal. We process the children of a node before processing the node itself. This allows us to identify and remove leaf nodes during the traversal, naturally propagating changes up the tree.
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def remove_leaf_nodes(root, target):
    if not root:
        return None
    # Prune both subtrees first (post-order).
    root.left = remove_leaf_nodes(root.left, target)
    root.right = remove_leaf_nodes(root.right, target)
    # After pruning, this node may have become a leaf; delete it if it matches.
    if not root.left and not root.right and root.val == target:
        return None
    return root
```
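To check the two examples, the sketch below pairs the post-order solution with two hypothetical helpers, build_tree and to_level_order, which use a LeetCode-style level-order list encoding (None marks a missing child). That encoding is an assumption made for testing convenience, not part of the problem statement.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def remove_leaf_nodes(root, target):
    # Same post-order solution as above.
    if not root:
        return None
    root.left = remove_leaf_nodes(root.left, target)
    root.right = remove_leaf_nodes(root.right, target)
    if not root.left and not root.right and root.val == target:
        return None
    return root


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Hypothetical helper: serialize to level order, trimming trailing Nones."""
    result = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            result.append(None)
            continue
        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while result and result[-1] is None:
        result.pop()
    return result


print(to_level_order(remove_leaf_nodes(build_tree([1, 2, 3, 2, 4, 2, 2]), 2)))
# [1, 2, 3, None, 4]
print(to_level_order(remove_leaf_nodes(build_tree([1, 3, 3, 3, 2]), 3)))
# [1, 3, None, None, 2]
```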
Time complexity: O(N), where N is the number of nodes in the tree. Each node is visited exactly once.
Space complexity: O(H), where H is the height of the tree, due to the recursion stack. In the worst case (a skewed tree), H is O(N); in the best case (a balanced tree), H is O(log N).
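The O(H) stack bound matters in practice for Python: CPython's default recursion limit (see sys.getrecursionlimit(), normally 1000) is smaller than the 3000-node maximum in the constraints, so the recursive version can raise RecursionError on a fully skewed tree. One way around this, sketched below, is the same post-order idea driven by an explicit stack; the function name remove_leaf_nodes_iterative is hypothetical. It still runs in O(N) time with O(H) extra space.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def remove_leaf_nodes_iterative(root, target):
    """Hypothetical variant: post-order deletion with an explicit stack."""
    dummy = TreeNode(0, left=root)  # sentinel parent so the root needs no special case
    # Each entry: (node, parent, is_left_child, children_done)
    stack = [(root, dummy, True, False)]
    while stack:
        node, parent, is_left, children_done = stack.pop()
        if node is None:
            continue
        if not children_done:
            # Revisit this node after both subtrees have been processed.
            stack.append((node, parent, is_left, True))
            stack.append((node.right, node, False, False))
            stack.append((node.left, node, True, False))
        elif node.left is None and node.right is None and node.val == target:
            # Subtrees are already pruned, so this node is a leaf now: detach it.
            if is_left:
                parent.left = None
            else:
                parent.right = None
    return dummy.left


# A 3000-node chain whose values all equal target is removed completely,
# with no deep Python recursion.
chain = None
for _ in range(3000):
    chain = TreeNode(5, left=chain)
print(remove_leaf_nodes_iterative(chain, 5))  # None
```

Pushing the left child last makes it pop first, so the left subtree, then the right subtree, then the node itself are handled in post-order, matching the recursive version.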