You are given the root of a binary tree in which every node has a unique value. You are also given the values of two different nodes, `node1` and `node2`. Your task is to write a function that finds the distance between these two nodes in the binary tree. The distance between two nodes is defined as the number of edges on the path from `node1` to `node2`.
For example, consider the following binary tree:
```
          3
        /   \
       5     1
      / \   / \
     6   2 0   8
        / \
       7   4
```
Your function should take the root of the tree and the values of the two nodes as input and return the distance between them. Focus on an efficient solution and consider edge cases such as when one or both nodes are not present in the tree.
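To try the solutions below on this example, it helps to build the tree in code. The following is a minimal sketch that assumes a LeetCode-style level-order list in which `None` marks a missing child. `build_tree` is a hypothetical helper that is not part of the original problem, and `TreeNode` has the same shape as the class used by the brute-force solution.

```python
from collections import deque


class TreeNode:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


def build_tree(values):
    """Build a binary tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    index = 1
    while queue and index < len(values):
        node = queue.popleft()
        # Attach the left child, then the right child, when present
        if index < len(values) and values[index] is not None:
            node.left = TreeNode(values[index])
            queue.append(node.left)
        index += 1
        if index < len(values) and values[index] is not None:
            node.right = TreeNode(values[index])
            queue.append(node.right)
        index += 1
    return root


# The example tree from the problem statement
example_root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
```

The two `None` entries stand for the missing children of node 6, so 7 and 4 attach to node 2 as in the diagram.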
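The brute-force way to find the distance between two people in a family tree is to look at every possible route from the top of the tree to each of them and compare those routes. It is like taking every possible road in a city to find the quickest way to get somewhere: we consider every branch of the tree and check how far apart the two people are along each pair of routes. Because node values are unique, each node actually has exactly one root-to-node path, so this approach does more bookkeeping than the problem requires.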
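Here's how the algorithm works step by step:
1. Run a depth-first search from the root and record every root-to-node path that ends at `node1`; do the same for `node2`.
2. Compare each recorded path for `node1` with each recorded path for `node2`. The values the two paths share form the prefix from the root down to the lowest common ancestor, and each value that appears in only one path contributes one edge of the route between the two nodes.
3. Count those unshared values (union size minus intersection size) and keep the smallest count.
4. If either node was never found, return -1.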
```python
class TreeNode:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


def find_distance_brute_force(root, node1_value, node2_value):
    def find_all_paths(start_node, target_value, current_path, all_paths):
        # Depth-first search that records every root-to-target path
        if start_node is None:
            return
        current_path.append(start_node.value)
        if start_node.value == target_value:
            all_paths.append(list(current_path))
        find_all_paths(start_node.left, target_value, current_path, all_paths)
        find_all_paths(start_node.right, target_value, current_path, all_paths)
        current_path.pop()  # backtrack before returning to the parent

    all_paths_node1 = []
    find_all_paths(root, node1_value, [], all_paths_node1)
    all_paths_node2 = []
    find_all_paths(root, node2_value, [], all_paths_node2)

    min_distance = float('inf')
    # Compare every root-to-node1 path with every root-to-node2 path
    for path1 in all_paths_node1:
        for path2 in all_paths_node2:
            # Node values are unique, so the paths can be compared as sets
            set1 = set(path1)
            set2 = set(path2)
            union_length = len(set1.union(set2))
            intersection_length = len(set1.intersection(set2))
            # Shared values form the root-to-LCA prefix; each unshared value
            # contributes exactly one edge of the route between the targets
            distance = union_length - intersection_length
            min_distance = min(min_distance, distance)

    if min_distance == float('inf'):
        return -1  # at least one of the nodes is not in the tree
    return min_distance
```
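To see why the set arithmetic yields an edge count, here is the formula applied to two pairs of nodes from the example tree. The root-to-node paths are written out by hand from the diagram, and `set_distance` is a hypothetical helper that isolates the distance calculation from the brute-force function.

```python
def set_distance(path1, path2):
    """Edges between two targets, given their root-to-node value paths."""
    set1, set2 = set(path1), set(path2)
    # Shared values form the root-to-LCA prefix; each value found in only
    # one path contributes exactly one edge of the route between the targets
    return len(set1 | set2) - len(set1 & set2)


print(set_distance([3, 5, 2, 7], [3, 1, 8]))  # route 7 -> 2 -> 5 -> 3 -> 1 -> 8
print(set_distance([3, 5, 6], [3, 5, 2, 4]))  # route 6 -> 5 -> 2 -> 4
```

The first call prints 5 (only the root is shared) and the second prints 3 (the root and node 5 are shared).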
To find the distance between two nodes in a binary tree efficiently, we'll first find the lowest common ancestor of those two nodes. Then, we'll calculate the distance from the ancestor to each of the nodes and sum those distances.
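Written as a formula, with $\ell$ the lowest common ancestor of nodes $a$ and $b$ and $d(\cdot,\cdot)$ the number of edges between two nodes, the approach computes the sum below. The second and third lines are worked instances for the example tree.

$$
\begin{aligned}
d(a, b) &= d(\ell, a) + d(\ell, b) \\
d(7, 8) &= d(3, 7) + d(3, 8) = 3 + 2 = 5 \\
d(6, 4) &= d(5, 6) + d(5, 4) = 1 + 2 = 3
\end{aligned}
$$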
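Here's how the algorithm works step by step:
1. Find the lowest common ancestor (LCA): recurse into both subtrees; a node that itself matches a target is returned as-is, and a node whose left and right calls both find a target is the LCA.
2. Measure how many edges below the LCA each target sits, using `find_level`.
3. Add the two distances. If there is no LCA, or either distance is -1, at least one node is missing from the tree, so return -1.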
```python
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def find_distance_between_nodes(root_node, first_node, second_node):
    # first_node and second_node are node values, not Node objects
    lowest_common_ancestor = find_lowest_common_ancestor(
        root_node, first_node, second_node)
    if lowest_common_ancestor is None:
        return -1  # empty tree, or neither value is in the tree
    distance_to_first_node = find_level(
        lowest_common_ancestor, first_node, 0)
    distance_to_second_node = find_level(
        lowest_common_ancestor, second_node, 0)
    if distance_to_first_node == -1 or distance_to_second_node == -1:
        return -1  # only one of the two values is in the tree
    return distance_to_first_node + distance_to_second_node


def find_lowest_common_ancestor(root_node, first_node, second_node):
    # Base case: empty tree
    if root_node is None:
        return None
    # A node matching either target is returned as-is: if the other target
    # is in its subtree, this node is the LCA; otherwise an ancestor
    # combines this result with the other subtree's result
    if root_node.data == first_node or root_node.data == second_node:
        return root_node
    left_lowest_common_ancestor = find_lowest_common_ancestor(
        root_node.left, first_node, second_node)
    right_lowest_common_ancestor = find_lowest_common_ancestor(
        root_node.right, first_node, second_node)
    if (left_lowest_common_ancestor is not None and
            right_lowest_common_ancestor is not None):
        # If both left and right calls return non-None,
        # then this node is the LCA
        return root_node
    if left_lowest_common_ancestor is not None:
        return left_lowest_common_ancestor
    return right_lowest_common_ancestor


def find_level(root_node, target_node, level):
    # Returns the number of edges from the starting node down to
    # target_node, or -1 if target_node is not in this subtree
    if root_node is None:
        return -1
    if root_node.data == target_node:
        return level
    downlevel = find_level(root_node.left, target_node, level + 1)
    if downlevel == -1:
        downlevel = find_level(root_node.right, target_node, level + 1)
    return downlevel
```
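The version above can walk the tree up to three times (once for the LCA and once per `find_level` call). As an alternative sketch, the same LCA idea can be folded into a single post-order pass: each subtree reports how far down each target sits, and the first node that sees both targets is their LCA, so it can return the summed distance directly. This assumes unique values and a `Node` like the one above, extended here with optional `left`/`right` constructor arguments for convenience; `distance_single_pass` is a hypothetical name.

```python
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def distance_single_pass(root, a, b):
    """Edges between values a and b in one post-order pass; -1 if either is missing."""

    def one_level_up(left_depth, right_depth):
        # Distance from the current node to a target found in a child subtree
        if left_depth != -1:
            return left_depth + 1
        if right_depth != -1:
            return right_depth + 1
        return -1

    def walk(node):
        # Returns (edges down to a or -1, edges down to b or -1, answer or -1)
        if node is None:
            return -1, -1, -1
        left_a, left_b, left_answer = walk(node.left)
        if left_answer != -1:
            return -1, -1, left_answer  # already solved in the left subtree
        right_a, right_b, right_answer = walk(node.right)
        if right_answer != -1:
            return -1, -1, right_answer  # already solved in the right subtree
        down_a = 0 if node.data == a else one_level_up(left_a, right_a)
        down_b = 0 if node.data == b else one_level_up(left_b, right_b)
        if down_a != -1 and down_b != -1:
            return down_a, down_b, down_a + down_b  # this node is the LCA
        return down_a, down_b, -1

    return walk(root)[2]


example = Node(3,
               Node(5, Node(6), Node(2, Node(7), Node(4))),
               Node(1, Node(0), Node(8)))
print(distance_single_pass(example, 7, 8))
```

On the example tree this prints 5. The same call handles the edge cases from the table: a value paired with itself gives 0, and a missing value gives -1. It is still recursive, so very deep trees remain a concern.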
Case | How to Handle |
---|---|
Root is null | Return -1 immediately as there is no tree to traverse. |
Either n1 or n2 is null | Return -1 immediately, since a null value cannot match any node in the tree. |
n1 or n2 is not present in the tree | The basic LCA recursion does not detect this on its own (it returns whichever node it finds), so check that both nodes were reached, e.g. that neither `find_level` call returned -1, and return -1 otherwise. |
n1 and n2 are the same node | Distance is 0: the LCA is the node itself, and its distance to both targets is 0. |
Large, skewed tree could lead to stack overflow with recursive solutions | Consider an iterative solution using an explicit stack for traversal; Python's default recursion limit (1000 frames) is reached quickly on a skewed tree. |
Integer overflow during distance calculation in very deep trees | Not a practical concern: a distance never exceeds the number of nodes minus one, so a standard integer suffices, and Python integers do not overflow. |
One node is an ancestor of the other | The LCA is the ancestor itself, so its distance term is 0 and the answer is the depth of the other node below it. |
Tree with only one node | If n1 and n2 are this node, distance is 0; otherwise return -1. |
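For the deep, skewed trees mentioned in the table, one iterative option is sketched below. It walks the tree with an explicit stack, records each node's parent and depth, then lifts the deeper target until both targets are level and climbs them together until they meet at the lowest common ancestor. It assumes unique values and the `TreeNode` shape from the brute-force solution, extended here with optional `left`/`right` constructor arguments; `distance_iterative` is a hypothetical name.

```python
class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def distance_iterative(root, a, b):
    """Edges between values a and b without recursion; -1 if either is missing."""
    if root is None:
        return -1
    parent = {root.value: None}  # value -> parent value
    depth = {root.value: 0}      # value -> edges from the root
    stack = [root]               # explicit stack replaces the call stack
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child.value] = node.value
                depth[child.value] = depth[node.value] + 1
                stack.append(child)
    if a not in depth or b not in depth:
        return -1
    distance = 0
    # Lift the deeper node until both are at the same depth
    while depth[a] > depth[b]:
        a = parent[a]
        distance += 1
    while depth[b] > depth[a]:
        b = parent[b]
        distance += 1
    # Climb both together until they meet at the LCA
    while a != b:
        a, b = parent[a], parent[b]
        distance += 2
    return distance


example = TreeNode(3,
                   TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                   TreeNode(1, TreeNode(0), TreeNode(8)))
print(distance_iterative(example, 7, 8))
```

On the example tree this prints 5. The trade-off is O(n) extra memory for the parent and depth maps in exchange for freedom from the recursion limit.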