Given the root of a binary tree, a node X in the tree is called good if the path from the root to X contains no node with a value greater than X.
Return the number of good nodes in the binary tree.
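For example, consider the binary tree:
    3
   / \
  1   4
 /   / \
3   1   5
The good nodes are the root 3, the node 4, the node 5, and the leaf 3 under node 1 (its path 3 -> 1 -> 3 contains nothing greater than 3). The two nodes with value 1 are not good, because the root 3 lies on each of their paths.
Thus, the output should be 4.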
As another example:
      3
     / \
    3   null
   / \
  4   2
Node 2 is not good, because its path 3 -> 3 -> 2 contains a 3, which is greater than 2. The other three nodes are good, so the output should be 3.
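To make the definition concrete, the sketch below lists every node's root-to-node path and checks it directly against the definition. It is a minimal illustration only: TreeNode mirrors the class used in the solutions, and path_report is a hypothetical helper, not part of either solution.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def path_report(root: Optional[TreeNode]) -> List[Tuple[List[int], bool]]:
    """Return (root-to-node path, is_good) for every node, in preorder."""
    report = []

    def visit(node, path):
        if node is None:
            return
        path = path + [node.val]
        # Good: no value on the path (the node itself included) exceeds node.val.
        report.append((path, max(path) <= node.val))
        visit(node.left, path)
        visit(node.right, path)

    visit(root, [])
    return report


# Second example: root 3, left child 3 with children 4 and 2, no right child.
root2 = TreeNode(3, TreeNode(3, TreeNode(4), TreeNode(2)))
for path, good in path_report(root2):
    print(" -> ".join(map(str, path)), "good" if good else "not good")
```

For the second example this prints one line per node and marks only 3 -> 3 -> 2 as not good.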
What algorithm would you use to solve this problem? What is the time and space complexity of your solution?
This problem asks us to find the number of "good" nodes in a binary tree. A node is considered "good" if all nodes on the path from the root to that node have values less than or equal to it.
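A naive solution traverses the tree and, for each node, scans the path from the root to that node to check whether any node on it has a value greater than the current node. This can be done using recursion or an iterative approach with a stack, carrying the path along as we go. Because each node scans and copies a path of length up to H (the height of the tree), this approach takes O(N·H) time for N nodes, which degrades to O(N²) on a skewed tree.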
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def goodNodes_naive(root):
    # A node is good if no ancestor on its root-to-node path is greater than it.
    def is_good(node, path):
        for val in path:
            if val > node.val:
                return False
        return True

    def traverse(node, path):
        nonlocal count
        if not node:
            return
        if is_good(node, path):
            count += 1
        # path + [node.val] builds a new list, so each subtree sees only its own ancestors.
        traverse(node.left, path + [node.val])
        traverse(node.right, path + [node.val])

    count = 0
    traverse(root, [])
    return count
# Example Usage
root = TreeNode(3, TreeNode(1, TreeNode(3)), TreeNode(4, TreeNode(1), TreeNode(5)))
print(goodNodes_naive(root)) # Output: 4
root2 = TreeNode(3, TreeNode(3, TreeNode(4), TreeNode(2)))
print(goodNodes_naive(root2)) # Output: 3
root3 = TreeNode(1)
print(goodNodes_naive(root3)) # Output: 1
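The naive approach was described as possible with either recursion or an explicit stack. A minimal sketch of the stack-based variant follows; goodNodes_naive_iterative is a hypothetical name, and the per-node path scan is kept on purpose so it stays a faithful version of the naive idea rather than the optimized one.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def goodNodes_naive_iterative(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    count = 0
    # Each entry is (node, values of the node's ancestors from the root down).
    stack = [(root, [])]
    while stack:
        node, path = stack.pop()
        # Same check as is_good: no ancestor may be greater than node.val.
        if all(val <= node.val for val in path):
            count += 1
        # Both children can share this list because it is never mutated.
        child_path = path + [node.val]
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, child_path))
    return count


root = TreeNode(3, TreeNode(1, TreeNode(3)), TreeNode(4, TreeNode(1), TreeNode(5)))
print(goodNodes_naive_iterative(root))  # 4
```

It has the same O(N·H) cost as the recursive naive version; the stack only replaces the call stack.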
The optimal solution also uses recursion, but instead of rescanning the path at every node, it passes down the maximum value seen so far on the path. If the current node's value is greater than or equal to that maximum, the node is good. We then update the maximum and pass it down to the left and right subtrees.
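def goodNodes(root):
    # dfs returns the number of good nodes in the subtree rooted at node,
    # where max_val is the largest value on the path above node.
    def dfs(node, max_val):
        if not node:
            return 0
        if node.val >= max_val:
            count = 1
        else:
            count = 0
        # The children's path now includes this node, so update the running maximum.
        max_val = max(max_val, node.val)
        count += dfs(node.left, max_val)
        count += dfs(node.right, max_val)
        return count

    return dfs(root, float('-inf'))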
# Example Usage
root = TreeNode(3, TreeNode(1, TreeNode(3)), TreeNode(4, TreeNode(1), TreeNode(5)))
print(goodNodes(root)) # Output: 4
root2 = TreeNode(3, TreeNode(3, TreeNode(4), TreeNode(2)))
print(goodNodes(root2)) # Output: 3
root3 = TreeNode(1)
print(goodNodes(root3)) # Output: 1
The optimal solution performs a depth-first traversal of the binary tree, visiting each node exactly once and doing constant work per node. Therefore, the time complexity is O(N), where N is the number of nodes in the tree.
The space complexity of the optimal solution is O(H), where H is the height of the binary tree, because the recursive call stack holds one frame per node on the current root-to-node path. In the worst case (a skewed tree), H equals N, giving O(N) space. In the best case (a balanced tree), H is about log2(N), giving O(log N) space.
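One practical consequence of the O(N) worst case: CPython caps recursion depth (typically 1000 frames by default, see sys.getrecursionlimit()), so the recursive dfs may raise RecursionError on a very deep, skewed tree. A sketch that keeps O(N) time and O(H) extra space but replaces the call stack with an explicit stack of (node, max-so-far) pairs is shown below; goodNodes_iterative is a hypothetical name.

```python
import sys
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def goodNodes_iterative(root: Optional[TreeNode]) -> int:
    count = 0
    # Each entry pairs a node with the largest value on the path above it.
    stack = [(root, float('-inf'))]
    while stack:
        node, max_val = stack.pop()
        if node is None:
            continue
        if node.val >= max_val:
            count += 1
        new_max = max(max_val, node.val)
        stack.append((node.left, new_max))
        stack.append((node.right, new_max))
    return count


# A left-skewed chain deeper than the default recursion limit.
depth = sys.getrecursionlimit() + 500
chain = None
for v in range(depth):
    chain = TreeNode(v, chain)  # the last node created becomes the root
# Values decrease from the root downward, so only the root is good.
print(goodNodes_iterative(chain))  # 1
```

The traversal order differs from the recursive version, but the count does not depend on order.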
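The initial max_val is set to negative infinity (float('-inf')) so that the root is always counted as good, which correctly handles negative node values; starting from 0, for example, would wrongly reject a negative root.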
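A quick way to see why the starting value matters is to run the same depth-first count with two different initial maxima on a tree whose values are all negative. In this sketch, count_good and its initial_max parameter are hypothetical, added only for the comparison.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_good(root, initial_max):
    # Same logic as goodNodes, but the starting maximum is a parameter.
    def dfs(node, max_val):
        if node is None:
            return 0
        count = 1 if node.val >= max_val else 0
        max_val = max(max_val, node.val)
        return count + dfs(node.left, max_val) + dfs(node.right, max_val)

    return dfs(root, initial_max)


# Tree: root -1 with left child -3 and right child -1.
neg = TreeNode(-1, TreeNode(-3), TreeNode(-1))
print(count_good(neg, float('-inf')))  # 2: the root and its right child
print(count_good(neg, 0))              # 0: the phantom 0 rejects every node
```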