There is an undirected connected tree with n nodes labeled from 0 to n - 1 and n - 1 edges.
You are given the integer n and the array edges where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the tree.
Return an array answer of length n where answer[i] is the sum of the distances between the ith node in the tree and all other nodes.
Example 1:
Input: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation: Node 0 is connected to nodes 1 and 2, and node 2 is connected to nodes 3, 4 and 5. dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5) = 1 + 1 + 2 + 2 + 2 = 8, so answer[0] = 8, and so on.
Example 2:
Input: n = 1, edges = []
Output: [0]
Example 3:
Input: n = 2, edges = [[1,0]]
Output: [1,1]
Constraints:
1 <= n <= 3 * 10^4
edges.length == n - 1
edges[i].length == 2
0 <= ai, bi < n
ai != bi
The brute-force method computes each answer independently. For every node, it finds the distance to each other node with a breadth-first search (BFS), which yields shortest paths because every edge has length 1, and adds those distances together. With one BFS per pair of nodes, this takes O(n^3) time.
The following implementation runs one BFS for every (start, destination) pair:
```python
from collections import deque


def sum_of_distances_in_tree_brute_force(number_of_nodes, edges):
    adjacency_list = [[] for _ in range(number_of_nodes)]
    for start_node, end_node in edges:
        adjacency_list[start_node].append(end_node)
        adjacency_list[end_node].append(start_node)

    result = []
    for start_node in range(number_of_nodes):
        total_distance = 0
        # Iterate through all other nodes and add up their distances from start_node
        for destination_node in range(number_of_nodes):
            if start_node == destination_node:
                continue
            # BFS from start_node until destination_node is dequeued
            queue = deque([(start_node, 0)])
            visited = {start_node}
            shortest_distance = float('inf')
            while queue:
                current_node, distance = queue.popleft()  # O(1), unlike list.pop(0)
                if current_node == destination_node:
                    shortest_distance = distance
                    break
                # Enqueue unvisited neighbors, one edge farther away
                for neighbor in adjacency_list[current_node]:
                    if neighbor not in visited:
                        visited.add(neighbor)
                        queue.append((neighbor, distance + 1))
            total_distance += shortest_distance
        result.append(total_distance)
    return result
```
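The brute-force code repeats a full BFS for every destination, even though a single BFS from the start node already reaches every node in the tree. Summing the distances from one BFS per start node lowers the cost to O(n^2) while staying a brute force. A minimal sketch; the function name is hypothetical and not part of the original solution:

```python
from collections import deque


def sum_of_distances_one_bfs_per_node(number_of_nodes: int, edges: list[list[int]]) -> list[int]:
    """Hypothetical O(n^2) brute force: one BFS per start node."""
    adjacency_list = [[] for _ in range(number_of_nodes)]
    for node_a, node_b in edges:
        adjacency_list[node_a].append(node_b)
        adjacency_list[node_b].append(node_a)

    result = []
    for start_node in range(number_of_nodes):
        distance = [-1] * number_of_nodes  # -1 marks "not reached yet"
        distance[start_node] = 0
        queue = deque([start_node])
        while queue:
            current_node = queue.popleft()
            for neighbor in adjacency_list[current_node]:
                if distance[neighbor] == -1:
                    distance[neighbor] = distance[current_node] + 1
                    queue.append(neighbor)
        # In a connected tree every node is reached, so every entry is >= 0
        result.append(sum(distance))
    return result


print(sum_of_distances_one_bfs_per_node(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [8, 12, 6, 10, 10, 10]
```

This is still far too slow for n = 3 * 10^4 (about 9 * 10^8 steps), but it is a convenient reference for testing the efficient solution on small trees.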
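To find the total distance from every node to all other nodes, the efficient approach avoids computing pairwise distances. It roots the tree at node 0 and makes two depth-first passes. The first pass (post-order) computes each node's subtree size, node_count, and the sum of distances from that node to the nodes in its own subtree; at the root, this sum is already the full answer. The second pass (pre-order) re-roots the tree one edge at a time: moving from a node to its child brings the node_count[child] nodes in the child's subtree one step closer and the other n - node_count[child] nodes one step farther. Each pass visits every node once, so the algorithm runs in O(n) time.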
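The recurrences behind both passes can be written out and checked against Example 1. The notation is introduced here for illustration and is not from the source: S(v) is the answer for node v, D(v) the distance sum within v's subtree (total_distances after the first pass), c(v) the subtree size node_count[v] with the tree rooted at node 0, and p the parent of v.

$$
D(u) = \sum_{v \in \mathrm{children}(u)} \bigl( D(v) + c(v) \bigr), \qquad S(0) = D(0)
$$

$$
S(v) = S(p) - c(v) + \bigl(n - c(v)\bigr) = S(p) + n - 2\,c(v)
$$

For Example 1 (n = 6), the leaves 1, 3, 4 and 5 have D = 0 and c = 1; node 2 has c(2) = 4 and D(2) = 3 × (0 + 1) = 3, so S(0) = D(0) = (0 + 1) + (3 + 4) = 8. Re-rooting then gives S(1) = 8 + 6 - 2 = 12, S(2) = 8 + 6 - 8 = 6, and S(3) = S(4) = S(5) = 6 + 6 - 2 = 10, matching the expected output.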
An implementation of the two passes:
```python
import sys


def sum_of_distances_in_tree(number_of_nodes: int, edges: list[list[int]]) -> list[int]:
    graph = [[] for _ in range(number_of_nodes)]
    for start_node, end_node in edges:
        graph[start_node].append(end_node)
        graph[end_node].append(start_node)

    node_count = [1] * number_of_nodes       # subtree size of each node (tree rooted at 0)
    total_distances = [0] * number_of_nodes  # pass 1: sum within subtree; pass 2: full answer

    # A path of 3 * 10^4 nodes recurses that deep, past CPython's default limit of 1000
    sys.setrecursionlimit(max(sys.getrecursionlimit(), number_of_nodes + 100))

    def depth_first_search(current_node: int, previous_node: int) -> None:
        # Post-order: compute subtree sizes and distance sums within each subtree
        for neighbor_node in graph[current_node]:
            if neighbor_node != previous_node:
                depth_first_search(neighbor_node, current_node)
                node_count[current_node] += node_count[neighbor_node]
                # Every node in the child's subtree is one edge farther from current_node
                total_distances[current_node] += total_distances[neighbor_node] + node_count[neighbor_node]

    def depth_first_search_two(current_node: int, previous_node: int) -> None:
        # Pre-order: re-root from current_node to each child
        for neighbor_node in graph[current_node]:
            if neighbor_node != previous_node:
                # node_count[child] nodes move one step closer, the other n - node_count[child] one step farther
                total_distances[neighbor_node] = (
                    total_distances[current_node]
                    - node_count[neighbor_node]
                    + (number_of_nodes - node_count[neighbor_node])
                )
                depth_first_search_two(neighbor_node, current_node)

    depth_first_search(0, -1)
    depth_first_search_two(0, -1)
    return total_distances
```
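Raising the recursion limit lets the recursive version handle a chain of 3 * 10^4 nodes, but the Python documentation warns that a too-high limit can crash the interpreter. A sketch of the same two passes without recursion follows; the function name and the `parent`/`order` arrays are illustrative and not from the source. An explicit stack produces a pre-order (every parent before its children), pass 1 walks that order backwards, and pass 2 walks it forwards.

```python
def sum_of_distances_in_tree_iterative(number_of_nodes: int, edges: list[list[int]]) -> list[int]:
    """Hypothetical iterative variant of the two-pass re-rooting solution."""
    graph = [[] for _ in range(number_of_nodes)]
    for start_node, end_node in edges:
        graph[start_node].append(end_node)
        graph[end_node].append(start_node)

    # Explicit-stack DFS from node 0: record each node's parent and a pre-order
    parent = [-1] * number_of_nodes
    order = []  # every node appears after its parent
    visited = [False] * number_of_nodes
    visited[0] = True
    stack = [0]
    while stack:
        current_node = stack.pop()
        order.append(current_node)
        for neighbor_node in graph[current_node]:
            if not visited[neighbor_node]:
                visited[neighbor_node] = True
                parent[neighbor_node] = current_node
                stack.append(neighbor_node)

    node_count = [1] * number_of_nodes
    total_distances = [0] * number_of_nodes

    # Pass 1: reversed pre-order finishes every child before its parent
    for current_node in reversed(order):
        parent_node = parent[current_node]
        if parent_node != -1:
            node_count[parent_node] += node_count[current_node]
            total_distances[parent_node] += total_distances[current_node] + node_count[current_node]

    # Pass 2: pre-order re-rooting, parent before child (order[0] is the root)
    for current_node in order[1:]:
        parent_node = parent[current_node]
        total_distances[current_node] = (
            total_distances[parent_node] + number_of_nodes - 2 * node_count[current_node]
        )

    return total_distances


print(sum_of_distances_in_tree_iterative(2, [[1, 0]]))  # [1, 1]
```

The traversal order is the only difference from the recursive version; the update formulas are identical, so both return the same array.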
Case | How to Handle |
---|---|
n = 1, no edges | Return [0] as the distance from a single node to itself is 0. |
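Empty edge list with n > 1 (disconnected graph) | Cannot occur under the constraints (edges.length == n - 1 and the input is a connected tree). If you validate untrusted input anyway, a node left unreached by the traversal means the graph is disconnected; report an error or mark the affected entries (for example with -1) rather than return 0s, which would look like valid sums. |
Large n (up to 3 * 10^4) | The brute force is O(n^3) and far too slow at this size; the two-pass approach runs in O(n) time and space with adjacency lists. In Python, a recursive DFS on a path-shaped tree goes n levels deep, past the default recursion limit of 1000, so raise the limit or use an iterative traversal. |
Tree is a star graph (one center connected to all others) | No special handling is needed. The center's sum is n - 1; each leaf's sum is 1 + 2(n - 2) = 2n - 3 (one step to the center, two steps to each other leaf). |
Tree is a linear chain (all nodes connected in a line) | No special handling is needed for correctness. Numbering the nodes 0 to n - 1 along the path, node i has sum i(i + 1)/2 + (n - 1 - i)(n - i)/2. The chain is the worst case for recursion depth (n levels). |
Integer overflow when calculating sums for large trees | The largest possible sum is 0 + 1 + ... + (n - 1) = n(n - 1)/2, reached at an endpoint of a chain; for n = 3 * 10^4 that is 449,985,000, which fits in a signed 32-bit integer. Python integers never overflow; in Java/C++, a 64-bit long is a safe choice if the constraints grow. |
Negative or out-of-range node indices | Ruled out by the constraints (0 <= ai, bi < n). If the input is untrusted, validate it and raise an error for any index outside [0, n - 1]. |
Cycles in the input edges list | Ruled out: a connected graph with n nodes and n - 1 edges is a tree. If the input is untrusted, detect a cycle while building the graph (for example with union-find) and report an error, because the parent-skipping DFS assumes there are no cycles. |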
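The star, chain, and overflow rows can be checked numerically. This sketch (all helper names are hypothetical) compares closed forms, the star's center at n - 1 and each leaf at 2n - 3, and node i of a path at i(i + 1)/2 + (n - 1 - i)(n - i)/2, against a plain per-node BFS for small n, then evaluates the largest possible answer for n = 3 * 10^4.

```python
from collections import deque


def bfs_distance_sums(number_of_nodes: int, edges: list[tuple[int, int]]) -> list[int]:
    """Hypothetical reference helper: one BFS per node, O(n^2) overall."""
    adjacency = [[] for _ in range(number_of_nodes)]
    for node_a, node_b in edges:
        adjacency[node_a].append(node_b)
        adjacency[node_b].append(node_a)
    sums = []
    for source in range(number_of_nodes):
        dist = [-1] * number_of_nodes
        dist[source] = 0
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for neighbor in adjacency[node]:
                if dist[neighbor] == -1:
                    dist[neighbor] = dist[node] + 1
                    queue.append(neighbor)
        sums.append(sum(dist))
    return sums


def star_closed_form(n: int) -> list[int]:
    # Node 0 is the center: n - 1 leaves at distance 1.
    # Each leaf: 1 step to the center plus 2 steps to each of the other n - 2 leaves.
    return [n - 1] + [1 + 2 * (n - 2)] * (n - 1)


def chain_closed_form(n: int) -> list[int]:
    # Node i on the path 0-1-...-(n-1): 1 + ... + i to the left, 1 + ... + (n-1-i) to the right.
    return [i * (i + 1) // 2 + (n - 1 - i) * (n - i) // 2 for i in range(n)]


for n in range(2, 9):
    star_edges = [(0, leaf) for leaf in range(1, n)]
    chain_edges = [(i, i + 1) for i in range(n - 1)]
    assert bfs_distance_sums(n, star_edges) == star_closed_form(n)
    assert bfs_distance_sums(n, chain_edges) == chain_closed_form(n)

# Largest possible answer: an endpoint of a 3 * 10^4-node chain
largest = chain_closed_form(30_000)[0]
print(largest, largest < 2**31 - 1)  # 449985000 True
```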