
Sum of Distances in Tree

Hard
Microsoft
Topics: Trees, Graphs, Dynamic Programming

There is an undirected connected tree with n nodes labeled from 0 to n - 1 and n - 1 edges.

You are given the integer n and the array edges where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the tree.

Return an array answer of length n where answer[i] is the sum of the distances between the ith node in the tree and all other nodes.

Example 1:

Input: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation: Node 0 is connected to nodes 1 and 2, and node 2 is connected to nodes 3, 4, and 5.
We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
equals 1 + 1 + 2 + 2 + 2 = 8.
Hence, answer[0] = 8, and so on.

Example 2:

Input: n = 1, edges = []
Output: [0]

Example 3:

Input: n = 2, edges = [[1,0]]
Output: [1,1]

Constraints:

  • 1 <= n <= 3 * 10^4
  • edges.length == n - 1
  • edges[i].length == 2
  • 0 <= ai, bi < n
  • ai != bi
  • The given input represents a valid tree.

Solution


Clarifying Questions

When you get asked this question in a real-life environment, it will often be ambiguous (especially at FAANG). Make sure to ask these questions in that case:

  1. What is the range of values for 'n', the number of nodes? Are there any specific memory constraints?
  2. Are the edges in the input guaranteed to represent a valid tree (connected, no cycles)?
  3. Are the node indices 0-based or 1-based?
  4. If n is 0, should I return an empty array, or is that an invalid input?
  5. Is the graph directed or undirected?

Brute Force Solution

Approach

The brute force method calculates the sum of distances from each node to all other nodes in a tree. For each node, we visit every other node in the tree and calculate the distance between them. We then sum all these distances to find the total distance for that node.

Here's how the algorithm would work step-by-step:

  1. For every single node in the tree, consider it the starting point.
  2. From that starting node, find the shortest path to every other node in the tree.
  3. Add up all those shortest path lengths you just found for that starting node. This is the total distance from that node to all others.
  4. Repeat the previous two steps for every single node in the tree.
  5. The final result is a list of total distances, one for each node in the tree, showing how far each node is from all the others combined.

Code Implementation

from collections import deque


def sum_of_distances_in_tree_brute_force(number_of_nodes, edges):
    adjacency_list = [[] for _ in range(number_of_nodes)]
    for start_node, end_node in edges:
        adjacency_list[start_node].append(end_node)
        adjacency_list[end_node].append(start_node)

    result = []
    for start_node in range(number_of_nodes):
        # One BFS from start_node finds the distance to every other node at once
        distances = [-1] * number_of_nodes
        distances[start_node] = 0
        queue = deque([start_node])

        while queue:
            # deque.popleft() is O(1), unlike list.pop(0)
            current_node = queue.popleft()
            for neighbor in adjacency_list[current_node]:
                if distances[neighbor] == -1:
                    distances[neighbor] = distances[current_node] + 1
                    queue.append(neighbor)

        # Every node is reachable in a tree, so all entries are now >= 0
        result.append(sum(distances))

    return result

Big(O) Analysis

Time Complexity
O(n²): The algorithm iterates through each of the n nodes in the tree. For each starting node, a single traversal such as breadth-first search visits every node and every edge once. A tree has n - 1 edges, so one traversal costs O(n), and n traversals cost O(n²) in total. With n up to 3 * 10^4, that is on the order of 9 * 10^8 basic steps, which is far too slow in Python.
Space Complexity
O(N): Besides the output list, each traversal keeps a queue and a per-node distance (or visited) record, each holding at most N entries. The adjacency list takes O(N) because a tree has N - 1 edges. The traversal structures are rebuilt for each starting node rather than accumulated, so the auxiliary space is O(N).

Optimal Solution

Approach

The challenge is to find the total distance from each location in a network to all other locations. Instead of calculating every pairwise distance separately, this approach (commonly called re-rooting) computes the answer for one root in a single pass. It then derives every other location's answer from a neighbor's answer in constant time.

Here's how the algorithm would work step-by-step:

  1. First, pick any location as the root (the code below uses node 0) and compute the total distance from it to every other location. In the same pass, record the size of each location's subtree when the tree hangs from that root: the location itself plus everything below it.
  2. Next, for each neighbor of our starting point, calculate the distance from that neighbor to all the other locations. Here's the smart part: we can use the total distance we figured out in the first step to help us avoid recalculating everything from scratch.
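  3. When the focus moves from a location to one of its neighbors, every location on the neighbor's side of the connecting edge gets exactly 1 closer. Every location on the other side gets exactly 1 farther away. So the total changes by (number of locations that got farther) minus (number that got closer).
  4. Figure out how many locations are on each side. If the neighbor is a child of the current location in the tree rooted in step 1, its side is its subtree, whose size you already know. The other side holds the remaining n minus that many locations.
  5. Using these counts, update the total for the neighbor without recomputing any individual distances: neighbor's total = current total - (neighbor's subtree size) + (n - neighbor's subtree size).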
  6. Continue this process for all locations in the network. You are essentially moving the 'focus' of the calculation from one location to its neighbors, updating the total distance with each move, until all locations have been covered.
  7. By cleverly reusing earlier calculations, you get the total distance for each location without having to do a lot of redundant work.
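The steps above reduce to two recurrences. They are written below with the illustrative symbols count, sub, and answer, where u is any node and c is a child of p when the tree is rooted at node 0. In the code that follows, node_count plays the role of count, and total_distances holds sub after the first pass and answer after the second.

```latex
\begin{align*}
\mathrm{count}[u] &= 1 + \sum_{c \in \mathrm{children}(u)} \mathrm{count}[c] \\
\mathrm{sub}[u] &= \sum_{c \in \mathrm{children}(u)} \bigl(\mathrm{sub}[c] + \mathrm{count}[c]\bigr) \\
\mathrm{answer}[0] &= \mathrm{sub}[0] \\
\mathrm{answer}[c] &= \mathrm{answer}[p] - \mathrm{count}[c] + \bigl(n - \mathrm{count}[c]\bigr)
\end{align*}
```

A check against Example 1 with root 0 gives count[1] = 1, count[2] = 4, and answer[0] = sub[0] = 8. The recurrence then yields answer[1] = 8 - 1 + 5 = 12, answer[2] = 8 - 4 + 2 = 6, and answer[3] = 6 - 1 + 5 = 10, which matches the expected output.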

Code Implementation

def sum_of_distances_in_tree(number_of_nodes: int, edges: list[list[int]]) -> list[int]:
    graph = [[] for _ in range(number_of_nodes)]
    for start_node, end_node in edges:
        graph[start_node].append(end_node)
        graph[end_node].append(start_node)

    node_count = [1] * number_of_nodes
    total_distances = [0] * number_of_nodes

    # Iterative DFS from node 0. This avoids Python's default recursion limit of 1000,
    # which a chain of up to 3 * 10^4 nodes would exceed. It records each node's parent
    # and a visiting order in which every parent comes before its children.
    parent = [-1] * number_of_nodes
    visited = [False] * number_of_nodes
    visited[0] = True
    order = []
    stack = [0]
    while stack:
        current_node = stack.pop()
        order.append(current_node)
        for neighbor_node in graph[current_node]:
            if not visited[neighbor_node]:
                visited[neighbor_node] = True
                parent[neighbor_node] = current_node
                stack.append(neighbor_node)

    # Pass 1 (children before parents): subtree sizes, and each node's distance sum
    # to the nodes in its own subtree; for node 0 that is already its full answer
    for current_node in reversed(order):
        previous_node = parent[current_node]
        if previous_node != -1:
            node_count[previous_node] += node_count[current_node]
            total_distances[previous_node] += total_distances[current_node] + node_count[current_node]

    # Pass 2 (parents before children): moving from parent to child brings the child's
    # node_count nodes 1 closer and pushes the remaining nodes 1 farther away
    for current_node in order:
        previous_node = parent[current_node]
        if previous_node != -1:
            total_distances[current_node] = (
                total_distances[previous_node]
                - node_count[current_node]
                + (number_of_nodes - node_count[current_node])
            )

    return total_distances
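A randomized cross-check against the brute force is a cheap way to gain confidence in the re-rooting formula. The sketch below is self-contained, so it carries its own condensed versions of both methods under the hypothetical names rerooting_sums and bfs_sums. To test the functions in this article instead, pass sum_of_distances_in_tree and sum_of_distances_in_tree_brute_force as the fast and slow arguments. The helper random_tree is also hypothetical: it attaches each node i >= 1 to a random earlier node, which always produces a valid tree.

```python
import random
from collections import deque


def build_graph(n: int, edges: list[list[int]]) -> list[list[int]]:
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
    return graph


def rerooting_sums(n: int, edges: list[list[int]]) -> list[int]:
    # Condensed two-pass re-rooting over an iterative DFS order rooted at node 0
    graph = build_graph(n, edges)
    parent, seen, order, stack = [-1] * n, [False] * n, [], [0]
    seen[0] = True
    while stack:
        node = stack.pop()
        order.append(node)
        for nxt in graph[node]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = node
                stack.append(nxt)
    count, total = [1] * n, [0] * n
    for node in reversed(order):  # children before parents
        if parent[node] != -1:
            count[parent[node]] += count[node]
            total[parent[node]] += total[node] + count[node]
    for node in order:  # parents before children
        if parent[node] != -1:
            total[node] = total[parent[node]] - count[node] + (n - count[node])
    return total


def bfs_sums(n: int, edges: list[list[int]]) -> list[int]:
    # Brute-force oracle: one BFS per source node
    graph = build_graph(n, edges)
    result = []
    for source in range(n):
        dist = [-1] * n
        dist[source] = 0
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for nxt in graph[node]:
                if dist[nxt] == -1:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        result.append(sum(dist))
    return result


def random_tree(n: int, rng: random.Random) -> list[list[int]]:
    # Each node i >= 1 gets one edge to an earlier node, so the result is a tree
    return [[i, rng.randrange(i)] for i in range(1, n)]


def cross_check(fast=rerooting_sums, slow=bfs_sums, trials: int = 200,
                max_n: int = 12, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, max_n)
        edges = random_tree(n, rng)
        assert fast(n, edges) == slow(n, edges), (n, edges)
    return True
```

cross_check() returns True when every random tree agrees. A mismatch raises AssertionError and reports the offending n and edges, which gives you a small failing case to debug by hand.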

Big(O) Analysis

Time Complexity
O(n): The solution makes two passes over the tree. The first pass computes subtree sizes and the distance sum for an arbitrary root (node 0). The second pass derives each remaining node's answer from its parent's answer in O(1). Each pass touches every node and edge a constant number of times, and a tree has exactly n - 1 edges, so the total time is O(n).
Space Complexity
O(N): The adjacency list stores 2(N - 1) neighbor entries. The arrays for subtree sizes and total distances each hold N values. The traversal's call stack (or explicit stack) can grow to at most N entries. So the auxiliary space is O(N).

Edge Cases

Case | How to Handle
n = 1, no edges | Return [0], since the only node's distance to itself is 0.
No edges or too few edges for n > 1 (disconnected graph) | Cannot happen under the constraints, which guarantee a valid tree with n - 1 edges. If the input is untrusted, check that edges.length == n - 1 and reject any input that leaves a node unreachable from node 0, since distances to unreachable nodes are undefined.
Large n (up to 3 * 10^4) | Use adjacency lists and an O(n) algorithm such as the two-pass re-rooting approach. Avoid naive recursion on deep trees: Python's default recursion limit is 1000, and a chain of 3 * 10^4 nodes is far deeper than that.
Star graph (one center connected to all others) | No special case is needed. The center's sum is n - 1. Each leaf's sum is 1 + 2(n - 2) = 2n - 3: 1 to the center, plus 2 to each of the other n - 2 leaves.
Linear chain (nodes connected in a line) | No special case is needed. If the chain is numbered 0, 1, ..., n - 1, node i's sum is i(i+1)/2 + (n-1-i)(n-i)/2, which makes a handy test oracle. The chain is also the deepest possible tree, so it is the input that breaks naive recursion.
Integer overflow | The largest possible sum is n(n-1)/2, at an end of a chain. For n = 3 * 10^4 that is about 4.5 * 10^8, which fits in a 32-bit signed integer, and Python integers never overflow. Use 64-bit integers (long in Java/C++) if n can be larger.
Out-of-range node indices (negative or >= n) | The constraints guarantee 0 <= ai, bi < n. For untrusted input, validate the indices and raise an error; otherwise, in Python a negative index silently refers to a node counted from the end of the list.
Cycles in the input edges list | The constraints rule them out. For untrusted input, detect a cycle while building the graph (for example with union-find) and raise an error, because the tree algorithms above assume exactly one path between each pair of nodes.
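For untrusted input, the validation checks for the malformed cases above (edge count, index range, self-loops, cycles) can be combined into one pass over the edges. The sketch below uses union-find. The helper validate_tree_input is hypothetical and not part of the problem's interface. It relies on one fact: with exactly n - 1 edges, the graph is a tree precisely when no edge joins two nodes that are already connected. So no separate connectivity check is needed.

```python
def validate_tree_input(n: int, edges: list[list[int]]) -> None:
    """Raise ValueError unless edges form a tree on nodes 0..n-1 (hypothetical helper)."""
    if n < 1:
        raise ValueError("n must be at least 1")
    if len(edges) != n - 1:
        raise ValueError(f"expected {n - 1} edges, got {len(edges)}")

    # Union-find over node indices; each set is one connected component so far
    parent = list(range(n))

    def find(node: int) -> int:
        while parent[node] != node:
            parent[node] = parent[parent[node]]  # path halving
            node = parent[node]
        return node

    for edge in edges:
        if len(edge) != 2:
            raise ValueError(f"edge {edge} must have exactly two endpoints")
        a, b = edge
        # Check the range explicitly: a negative index would otherwise wrap
        # around silently in Python list indexing
        if not (0 <= a < n and 0 <= b < n):
            raise ValueError(f"edge {edge} has a node index outside 0..{n - 1}")
        if a == b:
            raise ValueError(f"edge {edge} is a self-loop")
        root_a, root_b = find(a), find(b)
        if root_a == root_b:
            raise ValueError(f"edge {edge} closes a cycle")
        parent[root_a] = root_b
    # n - 1 edges with no cycle leave exactly one component, so the graph is a tree
```

Calling validate_tree_input(n, edges) before sum_of_distances_in_tree turns malformed input into a clear error instead of a wrong answer or an IndexError deep inside the traversal.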