
Find Distance in a Binary Tree

Medium
Google
Topics: Trees, Recursion

You are given a binary tree where each node has a unique value. You are also given the values of two nodes, node1 and node2, which are guaranteed to be present in the tree. Your task is to write a function that finds the distance between these two nodes in the tree. The distance between two nodes is defined as the number of edges on the shortest path between them.

Here are some specific requirements and edge cases to consider:

  1. The function should return the distance between node1 and node2. If either node is not found in the tree, return -1.
  2. If node1 and node2 are the same, the distance should be 0.
  3. The tree can be of any size, from empty to very large. Be mindful of potential stack overflow issues with recursive solutions for very large trees.

For example, consider the following binary tree:

      5
     / \
    3   6
   / \   
  2   4  
 /       
1         
  • find_distance(root, 2, 6) should return 3 (2 -> 3 -> 5 -> 6).
  • find_distance(root, 1, 4) should return 3 (1 -> 2 -> 3 -> 4).
  • find_distance(root, 3, 6) should return 2 (3 -> 5 -> 6).
  • find_distance(root, 5, 5) should return 0.

Explain your approach, analyze its time and space complexity, and provide code implementing your solution. How does your solution handle edge cases such as empty trees or nodes not found?

Solution


Clarifying Questions

When you get asked this question in a real-life environment, it will often be ambiguous (especially at FAANG). Make sure to ask these questions in that case:

  1. Can I assume the values 'n1' and 'n2' are guaranteed to exist in the tree?
  2. What should be returned if 'n1' or 'n2' is the root node, or if 'n1' and 'n2' are the same node?
  3. Are the node values unique within the binary tree?
  4. Is the given tree a Binary Search Tree (BST) or a general binary tree?
  5. What is the expected return value if either 'n1' or 'n2' (or both) are null/None?

Brute Force Solution

Approach

The brute force way to find the distance between two people in a family tree is to look at every possible path between them. We will explore all routes, like taking every possible road in a city to find the quickest way to get somewhere. This means considering all branches in the tree and checking how far apart the two people are along each route.

Here's how the algorithm would work step-by-step:

  1. Start at the first person in the family tree.
  2. List all the people they are directly connected to (their children or parent).
  3. For each of those connected people, list all the people they are directly connected to (again, their children or parent), skipping anyone already visited.
  4. Keep doing this, exploring every possible connection, until you find the second person you are looking for.
  5. Each time you find the second person, keep track of how many connections it took to get there. This is the distance for that particular path.
  6. Repeat the entire process, starting from the second person, and searching for the first.
  7. After exploring every possible path, compare all the distances you found.
  8. The smallest distance you found is the distance between the two people in the family tree.
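The exploration in these steps amounts to a breadth-first search over the tree treated as an undirected graph, where each node is linked to its children and its parent. Below is a minimal sketch of that idea; the helper name `find_distance_bfs` and the value-keyed adjacency dict are assumptions for illustration, not part of the original solution. Because BFS reaches nodes in order of distance, the first time it reaches the second node already gives the shortest path, so in a tree steps 6 to 8 (searching again from the other end and comparing) are not needed.

```python
from collections import deque


class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def find_distance_bfs(root, node1_value, node2_value):
    """Hypothetical helper: BFS from node1 over child and parent links."""
    if root is None:
        return -1

    # Build an undirected adjacency list keyed by node value (values are unique).
    neighbors = {}
    stack = [root]
    while stack:
        node = stack.pop()
        neighbors.setdefault(node.value, [])
        for child in (node.left, node.right):
            if child is not None:
                neighbors[node.value].append(child.value)
                neighbors.setdefault(child.value, []).append(node.value)
                stack.append(child)

    if node1_value not in neighbors or node2_value not in neighbors:
        return -1

    # BFS from node1; the first time node2 is dequeued gives the shortest distance.
    queue = deque([(node1_value, 0)])
    seen = {node1_value}
    while queue:
        value, dist = queue.popleft()
        if value == node2_value:
            return dist
        for nxt in neighbors[value]:
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, dist + 1))
    return -1  # unreachable when both nodes are in the same tree


# Example tree from the problem statement.
root = TreeNode(5,
                TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)),
                TreeNode(6))
print(find_distance_bfs(root, 2, 6))  # 3
print(find_distance_bfs(root, 1, 4))  # 3
```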

Code Implementation

class TreeNode:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None

def find_distance_brute_force(root, node1_value, node2_value):

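    def find_all_paths(start_node, target_value, current_path, all_paths):
        # Depth-first search that records every root-to-node path
        # ending at a node whose value equals target_value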
        if start_node is None:
            return

        current_path.append(start_node.value)

        if start_node.value == target_value:
            all_paths.append(list(current_path))

        find_all_paths(start_node.left, target_value, current_path, all_paths)
        find_all_paths(start_node.right, target_value, current_path, all_paths)

        current_path.pop()

    all_paths_node1 = []
    find_all_paths(root, node1_value, [], all_paths_node1)

    all_paths_node2 = []
    find_all_paths(root, node2_value, [], all_paths_node2)

    min_distance = float('inf')

    # Compare each root-to-node1 path with each root-to-node2 path.
    # With unique node values, each list holds at most one path.
    for path1 in all_paths_node1:
        for path2 in all_paths_node2:
            # Both paths start at the root, so their intersection is the
            # shared prefix from the root down to the lowest common ancestor
            set1 = set(path1)
            set2 = set(path2)
            union_length = len(set1.union(set2))
            intersection_length = len(set1.intersection(set2))

            # The nodes outside the shared prefix correspond one-to-one to
            # the edges on the path between the two targets
            distance = union_length - intersection_length

            min_distance = min(min_distance, distance)

    if min_distance == float('inf'):
        return -1
    else:
        return min_distance
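To see why `union_length - intersection_length` yields the distance, here is a worked check on the example tree. The two root-to-node paths are written out by hand for illustration. Their intersection is the shared prefix down to the lowest common ancestor (node 3), and the values left over (the symmetric difference) are exactly one per edge on the path 1 -> 2 -> 3 -> 4.

```python
# Root-to-node paths in the example tree (written out by hand).
path_to_1 = [5, 3, 2, 1]
path_to_4 = [5, 3, 4]

set1, set2 = set(path_to_1), set(path_to_4)
union_length = len(set1 | set2)          # {5, 3, 2, 1, 4} -> 5
intersection_length = len(set1 & set2)   # {5, 3} -> 2
distance = union_length - intersection_length
print(distance)  # 3

# Same quantity as the size of the symmetric difference {2, 1, 4}.
print(len(set1 ^ set2))  # 3
```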

Big(O) Analysis

Time Complexity
O(n)
The code as written runs two depth-first traversals, one per target value, and each visits every node exactly once, so each costs O(n), where n is the number of nodes in the tree. Copying current_path when a target is found costs O(h), where h is the tree height. Because node values are unique, each list holds at most one path, so the final pair loop runs at most once and its set operations cost O(h). The total is therefore O(n). The cost does not grow factorially: a tree has exactly one simple path between any two nodes, so there is no combinatorial set of routes to enumerate.
Space Complexity
O(n)
The code uses depth-first recursion, not a breadth-first queue. The recursion stack and current_path both grow to at most the tree height h, and each stored path (and the set built from it) holds at most h + 1 values, so the auxiliary space is O(h). In the worst case of a completely skewed tree, h = n - 1, which gives O(n).

Optimal Solution

Approach

To find the distance between two nodes in a binary tree efficiently, we'll first find the lowest common ancestor of those two nodes. Then, we'll calculate the distance from the ancestor to each of the nodes and sum those distances.

Here's how the algorithm would work step-by-step:

  1. First, find the lowest common ancestor (LCA) of the two given nodes. The LCA is the node that is an ancestor of both nodes but is farthest from the root. A node counts as its own ancestor, so if one node is an ancestor of the other, that node is the LCA.
  2. Next, find the distance from the lowest common ancestor to the first node.
  3. Then, find the distance from the lowest common ancestor to the second node.
  4. Finally, add the two distances together. This sum is the distance between the two nodes in the tree.
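The four steps combine into a single formula. Writing depth(x) for the number of edges from the root to x, the distance from the LCA down to a node is depth(node) - depth(LCA), so the sum in step 4 becomes the expression below. As a worked check on the example tree, nodes 1 and 4 have depths 3 and 2, and their LCA is node 3 at depth 1:

```latex
\begin{aligned}
\operatorname{dist}(u, v) &= d(\mathrm{LCA}, u) + d(\mathrm{LCA}, v) \\
  &= \operatorname{depth}(u) + \operatorname{depth}(v) - 2\,\operatorname{depth}(\mathrm{LCA}(u, v)) \\
\operatorname{dist}(1, 4) &= 3 + 2 - 2 \cdot 1 = 3
\end{aligned}
```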

Code Implementation

class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None

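def find_distance_between_nodes(
        root_node, first_node, second_node):
    # first_node and second_node are node values, not Node objects

    lowest_common_ancestor = find_lowest_common_ancestor(
        root_node, first_node, second_node)

    # Empty tree, or neither value is present
    if lowest_common_ancestor is None:
        return -1

    distance_to_first_node = find_level(
        lowest_common_ancestor, first_node, 0)

    distance_to_second_node = find_level(
        lowest_common_ancestor, second_node, 0)

    # If only one value is present, the LCA search returns that node
    # and the search for the missing value returns -1
    if distance_to_first_node == -1 or distance_to_second_node == -1:
        return -1

    return distance_to_first_node + distance_to_second_node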

def find_lowest_common_ancestor(
        root_node, first_node, second_node):
    # Base case, empty tree
    if root_node is None:
        return None

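    # A node matching either value is returned immediately. If the other
    # value lies below it, this node is already the LCA (ancestor case).
    if (root_node.data == first_node or
            root_node.data == second_node):
        return root_node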

    left_lowest_common_ancestor = find_lowest_common_ancestor(
        root_node.left, first_node, second_node)

    right_lowest_common_ancestor = find_lowest_common_ancestor(
        root_node.right, first_node, second_node)

    if (left_lowest_common_ancestor is not None and
            right_lowest_common_ancestor is not None):
        # If both left and right calls return non-None,
        # then this node is the LCA
        return root_node

    if left_lowest_common_ancestor is not None:
        return left_lowest_common_ancestor
    else:
        return right_lowest_common_ancestor

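def find_level(root_node, target_node, level):
    # Returns level plus the number of edges from root_node down to
    # target_node, or -1 if target_node is not in this subtree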
    if root_node is None:
        return -1

    if root_node.data == target_node:
        return level

    downlevel = find_level(root_node.left, target_node, level + 1)

    if downlevel == -1:
        downlevel = find_level(root_node.right, target_node, level + 1)

    return downlevel

Big(O) Analysis

Time Complexity
O(n)
Finding the Lowest Common Ancestor (LCA) in a binary tree requires traversing the tree, which takes O(n) time in the worst case, where n is the number of nodes in the tree. Finding the distance from the LCA to each of the two nodes also requires traversing the tree from the LCA downwards, again taking O(n) time in the worst case. Since these traversals run sequentially, the overall time complexity is O(n) + O(n) + O(n), which simplifies to O(n).
Space Complexity
O(H)
The primary space cost comes from the recursive calls made while finding the Lowest Common Ancestor (LCA) and calculating the distances. In the worst case, the recursion stack grows to the height H of the binary tree, and H can approach N in a skewed tree. The auxiliary space used by the recursion stack, which holds the call context and local variables of each pending call during the LCA search and distance computations, is therefore proportional to the height of the tree.

Edge Cases

| Case | How to handle |
|---|---|
| Root is null | Return -1 immediately, since there is no tree to traverse. |
| Either n1 or n2 is null | Return -1 immediately, since a null value cannot match any node in the tree. |
| n1 or n2 is not present in the tree | The LCA search alone does not detect this: if only one value is present, it returns that node. Check that both distances from the LCA are non-negative and return -1 otherwise. |
| n1 and n2 are the same node | The distance is 0: the LCA is the node itself, so both distances from the LCA are 0. |
| Large, skewed tree that could overflow the stack with a recursive solution | Use an iterative traversal with an explicit stack so the depth is not limited by the recursion limit. |
| Integer overflow in very deep trees | The distance is at most n - 1 edges, so this is not a practical concern. Python integers are arbitrary-precision; in fixed-width languages, any integer type that can hold the node count suffices. |
| One node is an ancestor of the other | The LCA is the ancestor node itself, so its distance is 0 and the result is the depth difference between the two nodes. |
| Tree with only one node | If n1 and n2 are both this node, the distance is 0; otherwise return -1. |
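For very deep or skewed trees, the recursive LCA and level searches can exceed Python's default recursion limit. One way around this, sketched below, swaps the recursive LCA search for a parent-pointer walk: an explicit-stack traversal records each node's parent and depth, then the deeper node is lifted until both nodes meet at their LCA. The function name `find_distance_iterative` and the `TreeNode` shape are assumptions for illustration, not part of the original solution. The sketch also applies the edge-case rules from the table (empty tree, null or missing value, same node).

```python
class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def find_distance_iterative(root, node1_value, node2_value):
    """Hypothetical helper: parent-pointer LCA without recursion."""
    if root is None or node1_value is None or node2_value is None:
        return -1

    # Explicit-stack DFS recording each value's parent value and depth.
    parent = {root.value: None}
    depth = {root.value: 0}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child.value] = node.value
                depth[child.value] = depth[node.value] + 1
                stack.append(child)

    if node1_value not in depth or node2_value not in depth:
        return -1  # at least one value is absent

    a, b = node1_value, node2_value
    distance = 0
    # Lift the deeper node until both are at the same depth.
    while depth[a] > depth[b]:
        a = parent[a]
        distance += 1
    while depth[b] > depth[a]:
        b = parent[b]
        distance += 1
    # Lift both together until they meet at the lowest common ancestor.
    while a != b:
        a, b = parent[a], parent[b]
        distance += 2
    return distance


# A left-skewed chain of 10,000 nodes, deeper than the default recursion limit.
chain_root = TreeNode(0)
current = chain_root
for v in range(1, 10_000):
    current.left = TreeNode(v)
    current = current.left
print(find_distance_iterative(chain_root, 0, 9_999))  # 9999
```

The dictionaries cost O(n) extra space instead of the O(H) recursion stack, which trades memory for immunity to stack depth limits.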