Given the root of a binary tree, return the number of nodes where the value of the node is equal to the average of the values in its subtree.
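Note: The average of n values is their sum divided by n, rounded down to the nearest integer. A node's subtree consists of the node itself and all of its descendants.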
For example:
Consider the following binary tree:
```
        4
       / \
      8   5
     / \   \
    0   1   6
```
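In this example:
- Node 4: its subtree contains 4, 8, 5, 0, 1, and 6, so the average is 24 / 6 = 4. This equals the node's value.
- Node 8: its subtree contains 8, 0, and 1, so the average is 9 / 3 = 3. This does not equal 8.
- Node 5: its subtree contains 5 and 6, so the average is 11 / 2 = 5.5, which rounds down to 5. This equals the node's value.
- Nodes 0, 1, and 6 are leaves, so each subtree's average is simply the node's own value.

So, the nodes with values 4, 5, 0, 1, and 6 satisfy the condition, and the answer is 5.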
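Node 5 qualifies only because its average, 11 / 2 = 5.5, is rounded down to 5. The small sketch below shows that rule in Python. The function names are illustrative and do not come from the original solution. It also shows why the rounding direction matters once negative values appear: Python's `//` floors toward negative infinity, while truncating toward zero (as C++ and Java integer division do) gives a different result for negative sums.

```python
def floor_average(total: int, count: int) -> int:
    # Average rounded down (toward negative infinity), which is what Python's // does.
    return total // count


def truncated_average(total: int, count: int) -> int:
    # Average rounded toward zero, which is how integer division behaves in C++ and Java.
    # Assumes count > 0, which always holds for a non-empty subtree.
    if total >= 0:
        return total // count
    return -((-total) // count)


if __name__ == "__main__":
    print(floor_average(11, 2), truncated_average(11, 2))  # 5 5  (node 5 in the example)
    print(floor_average(-3, 2), truncated_average(-3, 2))  # -2 -1
```

For the non-negative values in the example, the two functions agree. They diverge only when a subtree sum is negative and not evenly divisible.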
Could you implement a function that takes the root of a binary tree as input and returns the number of nodes where the node's value equals the average of its subtree's values?
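When you are asked this question in a real interview, it will often be ambiguous (especially at FAANG). In that case, ask clarifying questions before you start coding. For example: how is the average rounded, what should be returned for an empty tree, and can node values be negative or large enough to overflow?
The brute force method explores every subtree of the given tree. For each subtree, we calculate the sum of its node values and the number of nodes, then compare the average (rounded down) with the value of the subtree's root to see if they match. Because every subtree is summed from scratch, each node is re-visited once for each of its ancestors, itself included.
Here's how the algorithm would work step-by-step:
1. Visit every node in the tree, treating each one as the root of a subtree.
2. For the current node, traverse its whole subtree to add up the values and count the nodes.
3. Divide the sum by the count, rounding down, and compare the result with the node's value.
4. If they are equal, increase the running count.
5. After every node has been checked, return the count.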
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_nodes_equal_to_average_of_subtree(root):
    def subtree_sum_and_count(node):
        # Walk the entire subtree rooted at node and return (sum, node count).
        # Nothing is cached, so the same nodes are re-summed for every ancestor.
        if not node:
            return 0, 0
        left_sum, left_count = subtree_sum_and_count(node.left)
        right_sum, right_count = subtree_sum_and_count(node.right)
        return node.val + left_sum + right_sum, 1 + left_count + right_count

    def count_matches(node):
        if not node:
            return 0
        # Treat this node as the root of a subtree and measure that subtree from scratch
        subtree_sum, subtree_count = subtree_sum_and_count(node)
        # Check if the node's value equals the subtree average, rounded down
        matches = 1 if node.val == subtree_sum // subtree_count else 0
        # Repeat the same check for every other node in the tree
        return matches + count_matches(node.left) + count_matches(node.right)

    return count_matches(root)
```
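The cost of the brute-force method can be counted directly. Write $T(v)$ for the subtree rooted at $v$ and $n$ for the number of nodes. This notation is introduced here and is not from the original. Re-summing every subtree from scratch touches each node $u$ once per ancestor, $u$ itself included:

$$
\begin{aligned}
\text{work} &= \sum_{v} |T(v)| = \sum_{u} \bigl(\operatorname{depth}(u) + 1\bigr) \\
\text{skewed tree:}\quad \text{work} &= 1 + 2 + \cdots + n = \tfrac{n(n+1)}{2} = O(n^2) \\
\text{complete tree:}\quad \text{work} &\le n\bigl(\lfloor \log_2 n \rfloor + 1\bigr) = O(n \log n)
\end{aligned}
$$

The single-pass traversal removes this repetition. It reuses each child's sum and count, so every node is processed exactly once.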
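To count these nodes efficiently, we traverse the tree in post-order: left subtree, then right subtree, then the node itself. Each node receives the sum and node count of its two subtrees from its children. This way, the sum and size of every subtree are computed in a single pass instead of being recomputed for each ancestor.
Here's how the algorithm would work step-by-step:
1. Traverse the tree in post-order, so both children of a node are processed before the node itself.
2. For an empty child, report a sum of 0 and a node count of 0.
3. At each node, add its value to the sums reported by its left and right children, and add 1 to their node counts. This gives the sum and size of the node's subtree.
4. If the node's value equals that sum divided by the size (rounded down), increase the running count.
5. Return the subtree's sum and size to the parent so it can do the same computation.
6. After the root has been processed, return the running count.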
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_nodes_equal_to_average_of_subtree(root):
    number_of_nodes_equal_to_average = 0

    def postorder_traversal(node):
        # Returns (sum of values, number of nodes) for the subtree rooted at node
        nonlocal number_of_nodes_equal_to_average
        if not node:
            return 0, 0

        left_subtree_sum, left_subtree_node_count = postorder_traversal(node.left)
        right_subtree_sum, right_subtree_node_count = postorder_traversal(node.right)

        subtree_sum = node.val + left_subtree_sum + right_subtree_sum
        subtree_node_count = 1 + left_subtree_node_count + right_subtree_node_count

        # Compare current node's value with the subtree average, rounded down
        if node.val == subtree_sum // subtree_node_count:
            number_of_nodes_equal_to_average += 1

        return subtree_sum, subtree_node_count

    postorder_traversal(root)
    # Return the final count of nodes
    return number_of_nodes_equal_to_average
```
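To try the single-pass solution on the example tree, the block below uses `build_tree`. This is a hypothetical helper, not part of the original solution, that builds a tree from a LeetCode-style level-order list, with `None` marking a missing child. The solution is repeated in condensed form so the block runs on its own.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    # Hypothetical helper: build a tree from a level-order list, None = missing child.
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if present) becomes the left child
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that (if present) becomes the right child
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def count_nodes_equal_to_average_of_subtree(root: Optional[TreeNode]) -> int:
    # Same single-pass post-order solution as above, repeated so this block runs alone
    matches = 0

    def postorder_traversal(node):
        nonlocal matches
        if not node:
            return 0, 0
        left_sum, left_count = postorder_traversal(node.left)
        right_sum, right_count = postorder_traversal(node.right)
        subtree_sum = node.val + left_sum + right_sum
        subtree_count = 1 + left_count + right_count
        if node.val == subtree_sum // subtree_count:
            matches += 1
        return subtree_sum, subtree_count

    postorder_traversal(root)
    return matches


if __name__ == "__main__":
    example = build_tree([4, 8, 5, 0, 1, None, 6])
    print(count_nodes_equal_to_average_of_subtree(example))  # 5
```

The list `[4, 8, 5, 0, 1, None, 6]` encodes the tree from the problem statement. The `None` marks node 5's missing left child.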
Case | How to Handle |
---|---|
Null root (empty tree) | Return 0 if the root is null as there are no nodes to check. |
Single node tree | Return 1 if the root is the only node, as the node's value is trivially equal to the subtree's average. |
Tree with all identical values | The algorithm should correctly count all nodes as their values will always match the subtree average. |
Tree with only negative values | Summation works unchanged, but rounding matters: Python's `//` rounds toward negative infinity (-3 // 2 is -2), while integer division in C++ and Java truncates toward zero (-3 / 2 is -1). Confirm which rounding "rounded down" means before coding. |
Tree with very large positive or negative values leading to potential integer overflow. | Python integers have arbitrary precision and cannot overflow. In C++ or Java, accumulate subtree sums in a 64-bit type such as long or long long. |
Unbalanced/skewed tree | The recursion depth equals the tree height, which reaches N for a skewed tree. CPython's default recursion limit is 1000, so a skewed tree approaching that depth can raise RecursionError. Either raise the limit with sys.setrecursionlimit or use an iterative post-order traversal with an explicit stack. |
Tree with a mix of positive, negative, and zero values. | Zeros need no special handling. Division by zero cannot occur, because every non-null node's subtree contains at least the node itself, so the node count is always at least 1. |
Maximum tree size causing memory constraints or slow execution. | The single-pass post-order algorithm visits each node once, so it runs in O(N) time with O(H) extra space for the recursion stack, where N is the number of nodes and H is the tree height. The brute force method can take O(N^2) time on a skewed tree. |
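The recursion-depth concern in the skewed-tree row can be avoided with an explicit stack. Below is a minimal sketch of an iterative post-order variant. The name `count_nodes_equal_to_average_iterative` is hypothetical, and the `TreeNode` class is repeated so the block runs alone. Using nodes as dictionary keys assumes `TreeNode` keeps Python's default identity-based hashing, which holds because the class defines no `__eq__`.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_nodes_equal_to_average_iterative(root: Optional[TreeNode]) -> int:
    matches = 0
    # Finished node -> (subtree sum, subtree node count); removed once the parent uses it
    totals: Dict[TreeNode, Tuple[int, int]] = {}
    # Each entry is (node, children_done). A node is popped twice: first to push its
    # children, then (with children_done=True) after both children are finished.
    stack: List[Tuple[TreeNode, bool]] = [(root, False)] if root else []
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            stack.append((node, True))
            if node.right:
                stack.append((node.right, False))
            if node.left:
                stack.append((node.left, False))
            continue
        # A missing child (None) is never a key in totals, so it defaults to (0, 0)
        left_sum, left_count = totals.pop(node.left, (0, 0))
        right_sum, right_count = totals.pop(node.right, (0, 0))
        subtree_sum = node.val + left_sum + right_sum
        subtree_count = 1 + left_count + right_count
        if node.val == subtree_sum // subtree_count:
            matches += 1
        totals[node] = (subtree_sum, subtree_count)
    return matches
```

This sketch returns the same counts as the recursive version, for example 5 on the tree from the problem statement. Because no Python call frames accumulate, it also handles a path of thousands of nodes.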