
Count Nodes Equal to Average of Subtree

Medium
Meta
Topics: Trees, Recursion

Given the root of a binary tree, return the number of nodes where the value of the node is equal to the average of the values in its subtree.

Note:

  • The average of n elements is the sum of the n elements divided by n and rounded down to the nearest integer.
  • A subtree of root is a tree consisting of root and all of its descendants.
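The "rounded down" rule can be sketched in Python as shown below, assuming the average should be floored exactly as the note says. Python's `//` operator floors, which matches "rounded down" even for negative sums. Truncating the true quotient with `int()` rounds toward zero instead, and the two only differ when the sum is negative. The helper name `floored_average` is hypothetical.

```python
def floored_average(total: int, count: int) -> int:
    """Average of `count` values summing to `total`, rounded down (floored)."""
    if count <= 0:
        raise ValueError("a subtree always contains at least one node")
    # Python's // floors toward negative infinity, i.e. it "rounds down"
    return total // count


print(floored_average(11, 2))   # 5  (5.5 rounded down)
print(floored_average(-11, 2))  # -6 (-5.5 rounded down)
print(int(-11 / 2))             # -5 (truncation toward zero, not rounding down)
```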

For example:

Consider the following binary tree:

      4
     / \
    8   5
   / \   \
  0   1   6

In this example:

  • For the node with value 4: The average of its subtree is (4 + 8 + 5 + 0 + 1 + 6) / 6 = 24 / 6 = 4.
  • For the node with value 8: The average of its subtree is (8 + 0 + 1) / 3 = 9 / 3 = 3.
  • For the node with value 5: The average of its subtree is (5 + 6) / 2 = 11 / 2 = 5.5, which rounds down to 5.
  • For the node with value 0: The average of its subtree is 0 / 1 = 0.
  • For the node with value 1: The average of its subtree is 1 / 1 = 1.
  • For the node with value 6: The average of its subtree is 6 / 1 = 6.

So, the nodes with values 4, 5, 0, 1, and 6 satisfy the condition, and the answer is 5. The node with value 8 does not, because its subtree average is 3.
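To double-check the walkthrough, the sketch below builds the example tree and records each node's floored subtree average. The `TreeNode` class mirrors the one used in the code implementations. The helper `subtree_stats` and the `report` list are hypothetical names used only for this illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_stats(node, report):
    """Return (sum, count) for the subtree at `node`, appending (val, floored avg) to report."""
    if node is None:
        return 0, 0
    left_sum, left_count = subtree_stats(node.left, report)
    right_sum, right_count = subtree_stats(node.right, report)
    total = node.val + left_sum + right_sum
    count = 1 + left_count + right_count
    report.append((node.val, total // count))
    return total, count


# Example tree: 4 has children 8 and 5; 8 has children 0 and 1; 5 has right child 6
root = TreeNode(4, TreeNode(8, TreeNode(0), TreeNode(1)), TreeNode(5, None, TreeNode(6)))

report = []
subtree_stats(root, report)
for value, average in report:
    print(f"node {value}: subtree average {average}, match={value == average}")
print(sum(1 for value, average in report if value == average))  # 5
```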

Could you implement a function that takes the root of a binary tree as input and returns the number of nodes where the node's value equals the average of its subtree's values?

Solution


Clarifying Questions

When you get asked this question in a real-life environment, it will often be ambiguous (especially at FAANG). Make sure to ask these questions in that case:

  1. What is the range of values for the nodes in the binary tree? Can they be negative or non-integer?
  2. What should I return if the root is null (an empty tree)?
  3. Can the number of nodes in the tree be very large, and should I be mindful of potential stack overflow issues due to recursion?
  4. By 'average,' do you mean the integer result rounded down (floored), the result truncated toward zero, or the floating-point average? Floor and truncation differ when a subtree's sum is negative.
  5. If a node's value is equal to the average of its subtree, should I still continue to explore that subtree or can I stop?

Brute Force Solution

Approach

The brute force method treats every node as the root of its own subtree and recomputes that subtree's sum and node count from scratch by walking all of its descendants. It then compares the rounded-down average with the subtree root's value to see if they match.

Here's how the algorithm would work step-by-step:

  1. Start at the very top of the tree.
  2. Consider this node as the root of a subtree.
  3. Calculate the total value of all nodes within this subtree, including the root.
  4. Count how many nodes are in this subtree.
  5. Divide the total value by the number of nodes, rounding down, to find the average value for this subtree.
  6. Check if the average value you just computed is equal to the value of the root node of that subtree.
  7. If they are equal, increment a counter by one.
  8. Now, repeat this exact same process for every single node in the tree, considering each one as the root of a subtree.
  9. When you have checked every single node in the tree, the value of the counter will be the answer.

Code Implementation

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

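def count_nodes_equal_to_average_of_subtree(root):
    def subtree_sum_and_count(node):
        # Walk the entire subtree rooted at node from scratch and
        # return (sum of its values, number of nodes)
        if not node:
            return 0, 0
        left_sum, left_count = subtree_sum_and_count(node.left)
        right_sum, right_count = subtree_sum_and_count(node.right)
        return node.val + left_sum + right_sum, 1 + left_count + right_count

    def check_every_node(node):
        # Treat each node as the root of its own subtree and test it independently
        if not node:
            return 0
        subtree_sum, subtree_count = subtree_sum_and_count(node)

        # Check if the (rounded-down) average of the subtree equals the node's value
        matches_here = 1 if node.val == subtree_sum // subtree_count else 0

        # Repeat the whole process for every node in the left and right subtrees
        return matches_here + check_every_node(node.left) + check_every_node(node.right)

    return check_every_node(root)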

Big(O) Analysis

Time Complexity
O(n²)
The algorithm visits every node (n nodes) and, for each one, walks that node's entire subtree again to compute its sum and node count. In the worst case, a skewed tree that is effectively a linked list, the subtree sizes are n, n - 1, ..., 1, so the total work is n + (n - 1) + ... + 1 = n(n + 1)/2 node visits, which is O(n²). On a balanced tree the same approach does about n log n work, because each node is re-visited once for each of its ancestors.
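One way to see the quadratic cost is to count node visits. The sketch below keeps only the traversal pattern of the brute force approach and adds a `visits` counter; `brute_force_visits` and `make_chain` are hypothetical helpers for measurement, not part of the solution. On a right-skewed chain of n nodes it should report n(n + 1)/2 visits.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def make_chain(n):
    """Build a right-skewed chain of n nodes (worst case for the brute force)."""
    root = None
    for value in range(n, 0, -1):
        root = TreeNode(value, None, root)
    return root


def brute_force_visits(root):
    """Run the brute-force traversal pattern and return how many node visits it performs."""
    visits = 0

    def subtree_sum_and_count(node):
        nonlocal visits
        if node is None:
            return 0, 0
        visits += 1  # one visit per real node touched by a recomputation
        left_sum, left_count = subtree_sum_and_count(node.left)
        right_sum, right_count = subtree_sum_and_count(node.right)
        return node.val + left_sum + right_sum, 1 + left_count + right_count

    def visit_every_node(node):
        if node is None:
            return
        subtree_sum_and_count(node)  # recompute this node's subtree from scratch
        visit_every_node(node.left)
        visit_every_node(node.right)

    visit_every_node(root)
    return visits


for n in (10, 100, 500):
    print(n, brute_force_visits(make_chain(n)), n * (n + 1) // 2)
```

For n = 500 that is already 125,250 visits, compared with 500 for a single pass.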
Space Complexity
O(N)
No auxiliary data structures are allocated, so the dominant space cost is the recursion stack. The outer traversal and the nested subtree walk together never recurse deeper than the tree's height, so the stack holds O(H) frames at any moment. In the worst case, a skewed tree, H equals N, the number of nodes, so the space complexity is O(N).

Optimal Solution

Approach

To efficiently count nodes with values equal to the average of their subtree, we'll traverse the tree in post-order (children before their parent). Each node can then compute its subtree's sum and node count from the totals its children have already returned, so every subtree is measured in a single pass instead of being recomputed.

Here's how the algorithm would work step-by-step:

  1. Start at the very bottom of the tree (the leaves) and move upwards.
  2. For each node, calculate the sum of the values of all nodes in its subtree (including itself) and the total number of nodes in that subtree.
  3. To calculate these sums and counts, use the information calculated for the node's children.
  4. If the node's value is equal to the average value of its subtree (subtree sum divided by subtree node count), increment the counter.
  5. Return the calculated sum and count to the parent node so it can compute its own subtree totals. The counter is shared by all calls, so it does not need to be passed upward.
  6. Repeat these steps until you reach the root, at which point the counter will hold the total number of nodes meeting the specified condition.

Code Implementation

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def count_nodes_equal_to_average_of_subtree(root):
    number_of_nodes_equal_to_average = 0

    def postorder_traversal(node):
        nonlocal number_of_nodes_equal_to_average
        if not node:
            return 0, 0

        left_subtree_sum, left_subtree_node_count = postorder_traversal(node.left)
        right_subtree_sum, right_subtree_node_count = postorder_traversal(node.right)

        subtree_sum = node.val + left_subtree_sum + right_subtree_sum
        subtree_node_count = 1 + left_subtree_node_count + right_subtree_node_count

        # Compare current node's value with subtree average
        if node.val == subtree_sum // subtree_node_count:
            number_of_nodes_equal_to_average += 1

        return subtree_sum, subtree_node_count

    postorder_traversal(root)
    # Return the final count of nodes
    return number_of_nodes_equal_to_average
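A quick way to gain confidence that the single-pass version is correct is to compare it against the brute force on many random trees. The sketch below re-declares compact versions of both approaches under the hypothetical names `brute_force_count` and `single_pass_count`, plus a hypothetical `random_tree` builder that uses the standard `random` module with a fixed seed.

```python
import random


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_sum_and_count(node):
    """Sum and size of the subtree at node, recomputed from scratch."""
    if node is None:
        return 0, 0
    left_sum, left_count = subtree_sum_and_count(node.left)
    right_sum, right_count = subtree_sum_and_count(node.right)
    return node.val + left_sum + right_sum, 1 + left_count + right_count


def brute_force_count(node):
    """O(n^2): test every node by recomputing its whole subtree."""
    if node is None:
        return 0
    total, count = subtree_sum_and_count(node)
    here = 1 if node.val == total // count else 0
    return here + brute_force_count(node.left) + brute_force_count(node.right)


def single_pass_count(root):
    """O(n): post-order traversal that reuses the children's totals."""
    matches = 0

    def postorder(node):
        nonlocal matches
        if node is None:
            return 0, 0
        left_sum, left_count = postorder(node.left)
        right_sum, right_count = postorder(node.right)
        total = node.val + left_sum + right_sum
        count = 1 + left_count + right_count
        if node.val == total // count:
            matches += 1
        return total, count

    postorder(root)
    return matches


def random_tree(size, rng):
    """Random binary tree with `size` nodes and small values in [-5, 5]."""
    if size == 0:
        return None
    left_size = rng.randint(0, size - 1)
    return TreeNode(rng.randint(-5, 5),
                    random_tree(left_size, rng),
                    random_tree(size - 1 - left_size, rng))


rng = random.Random(0)  # fixed seed so the run is reproducible
for _ in range(200):
    tree = random_tree(rng.randint(0, 30), rng)
    assert brute_force_count(tree) == single_pass_count(tree)
print("brute force and single pass agree on 200 random trees")
```

Small values in a narrow range make matches frequent, which exercises both branches of the comparison.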

Big(O) Analysis

Time Complexity
O(n)
The algorithm performs a single post-order depth-first search (DFS) of the binary tree. Each of the n nodes is visited exactly once, and the work per node (two additions, one floor division, one comparison) is constant, so the total time is O(n).
Space Complexity
O(H)
The only extra space is the recursion stack. Each active call stores a constant amount of data (the node reference and its children's sums and counts); the counter lives in the enclosing function and is shared, not copied into each frame. The stack depth never exceeds the tree's height H, so the auxiliary space is O(H). That is O(log N) for a balanced tree and O(N) in the worst case, a skewed tree that is effectively a linked list of N nodes.
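To make the O(H) bound concrete, this sketch adds a hypothetical `depth` parameter to the post-order helper and records the deepest call that reaches a real node. A chain of n nodes reaches depth n, while a perfect tree with 7 nodes (height 3) reaches only 3. The `make_chain` and `make_perfect` builders and `max_recursion_depth` are illustrative helpers, not part of the solution.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def make_chain(n):
    """Right-skewed chain of n nodes: height n."""
    root = None
    for _ in range(n):
        root = TreeNode(1, None, root)
    return root


def make_perfect(height):
    """Perfect binary tree with 2**height - 1 nodes."""
    if height == 0:
        return None
    return TreeNode(1, make_perfect(height - 1), make_perfect(height - 1))


def max_recursion_depth(root):
    """Run the post-order sum/count traversal and report the deepest real-node call."""
    deepest = 0

    def postorder(node, depth):
        nonlocal deepest
        if node is None:
            return 0, 0
        deepest = max(deepest, depth)  # depth of the current stack frame
        left_sum, left_count = postorder(node.left, depth + 1)
        right_sum, right_count = postorder(node.right, depth + 1)
        return node.val + left_sum + right_sum, 1 + left_count + right_count

    postorder(root, 1)
    return deepest


print(max_recursion_depth(make_chain(50)))   # 50
print(max_recursion_depth(make_perfect(3)))  # 3
```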

Edge Cases

Case | How to Handle
--- | ---
Null root (empty tree) | Return 0, since there are no nodes to check.
Single-node tree | Return 1: the node's subtree average is its own value divided by 1, which trivially equals its value.
Tree with all identical values | Every node is counted: a subtree of k nodes with value v sums to k * v, so its average is exactly v.
Tree with only negative values | Summation works as usual, but the rounding rule matters: "rounded down" means floor. Python's // floors (-11 // 2 == -6), whereas integer division in languages such as Java or C++ truncates toward zero (-11 / 2 == -5). Confirm the intended rule.
Very large positive or negative values (potential integer overflow) | Python integers have arbitrary precision, so sums cannot overflow. In languages with fixed-width integers, accumulate sums in a wider type such as long.
Unbalanced/skewed tree | The recursive call stack grows to the tree's height, up to N for a chain. CPython's default recursion limit is about 1000 frames, so a long enough chain raises RecursionError; an iterative post-order traversal with an explicit stack avoids this.
Tree with a mix of positive, negative, and zero values | Zeros need no special handling. Division by zero cannot occur, because the divisor is the subtree's node count, which is always at least 1 (the subtree includes the node itself).
Very large trees (memory or time limits) | The single-pass solution runs in O(N) time and O(H) extra space; avoid the O(N²) brute force for large inputs.
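The skewed-tree case mentions iterative solutions. Below is a minimal sketch of that alternative, assuming the same `TreeNode` shape as the code implementations: it performs the same post-order sum/count computation with an explicit stack and a dictionary of per-node results, so a chain deeper than Python's default recursion limit does not raise `RecursionError`. The function name `count_nodes_iterative` is hypothetical.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_nodes_iterative(root):
    """Count nodes whose value equals the floored average of their subtree,
    using an explicit stack instead of recursion."""
    if root is None:
        return 0  # empty tree: nothing to count
    totals = {None: (0, 0)}  # node -> (subtree_sum, subtree_count); None = empty child
    matches = 0
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after both children are processed (post-order)
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue
        left_sum, left_count = totals[node.left]
        right_sum, right_count = totals[node.right]
        subtree_sum = node.val + left_sum + right_sum
        subtree_count = 1 + left_count + right_count  # always >= 1, so no division by zero
        if node.val == subtree_sum // subtree_count:
            matches += 1
        totals[node] = (subtree_sum, subtree_count)
    return matches


example = TreeNode(4, TreeNode(8, TreeNode(0), TreeNode(1)), TreeNode(5, None, TreeNode(6)))
print(count_nodes_iterative(example))  # 5

chain = None
for _ in range(5000):
    chain = TreeNode(7, None, chain)
print(count_nodes_iterative(chain))  # 5000, with no RecursionError
```

The dictionary keeps every node's totals, so this version uses O(N) extra memory rather than O(H); it trades memory for immunity to the recursion limit.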