Kth Smallest Element in a Sorted Matrix

Medium
Google
Topics:
Arrays, Binary Search

Given an n x n matrix where each row and each column is sorted in ascending order, return the kth smallest element in the matrix. Note that it is the kth smallest element in sorted order, not the kth distinct element. You must find a solution with a memory complexity better than O(n^2).

Example 1: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8. Expected output: 13.

Example 2: matrix = [[-5]], k = 1. Expected output: -5.

Solution


Clarifying Questions

When you get asked this question in a real-life environment, it will often be ambiguous (especially at FAANG). Make sure to ask these questions in that case:

  1. What are the constraints on the size of the matrix n, and the value of k? Are they within reasonable limits (e.g., up to 500 or 1000) and what happens if k is larger than n*n?
  2. Can the elements in the matrix be negative, zero, or floating-point numbers?
  3. Are duplicate values allowed in the matrix? If so, how should they be handled when determining the kth smallest element?
  4. Is the matrix guaranteed to be square (n x n) or could it be rectangular (m x n) with rows and columns still sorted? If rectangular, how should the sorted property be interpreted for the return value?
  5. If the matrix is empty (n=0), or if k is zero or negative, what should be the return value?

Brute Force Solution

Approach

We need to find the Kth smallest number when all the numbers are put together. The brute force method is the most straightforward: we simply look at every single number, combine them, and then find the Kth smallest.

Here's how the algorithm would work step-by-step:

  1. Take all the numbers from the sorted matrix.
  2. Put all the numbers together in one big list.
  3. Sort all the numbers in the list from smallest to largest.
  4. Pick out the number that is at the Kth position in the list. Remember that the first element is at position 1, the second at position 2, and so on.

Code Implementation

def kth_smallest_element_brute_force(matrix, k_value):
    all_numbers = []
    # Extract all numbers from the matrix.
    for row in matrix:
        for element in row:
            all_numbers.append(element)
    
    # Sort all the numbers in ascending order. 
    # This step is essential to find the kth smallest element.
    all_numbers.sort()
    
    # Return the element at the (k_value - 1)th index
    # since lists are zero indexed.
    return all_numbers[k_value - 1]
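
For reference, the same four steps can be written more compactly with the standard library. This is only a sketch: it assumes a non-empty matrix and a valid k, and the name `kth_smallest_sorted` is illustrative rather than part of the solution above.

```python
from itertools import chain


def kth_smallest_sorted(matrix, k):
    """Return the kth smallest element (1-indexed) by flattening and sorting."""
    # chain.from_iterable flattens the rows into one stream of values,
    # and sorted() collects them into a new ascending list.
    flattened = sorted(chain.from_iterable(matrix))
    # k is 1-indexed, Python lists are 0-indexed.
    return flattened[k - 1]


print(kth_smallest_sorted([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
print(kth_smallest_sorted([[-5]], 1))  # -5
```

Both calls reproduce the expected outputs from the problem statement.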

Big(O) Analysis

Time Complexity
O(n² log n). The algorithm first copies all n x n elements of the matrix into a list, which takes O(n²) time. It then sorts the list with a comparison-based sort, which takes O(m log m) time for a list of m elements. Here m = n², so the sorting step takes O(n² log n²) = O(2 n² log n) = O(n² log n) time. Selecting the Kth element after sorting takes O(1) time. The sorting step dominates, so the overall time complexity is O(n² log n).
Space Complexity
O(n²). The solution creates a new list that stores every element of the n x n matrix, so the list requires space proportional to n² (more generally, m * n for an m x n matrix). The list stays allocated until the Kth element is picked, giving a space complexity of O(n²). Note that this does not satisfy the problem's requirement of memory better than O(n²).

Optimal Solution

Approach

To efficiently find the Kth smallest element, we will use a method similar to searching in a phone book: intelligently narrowing down the possible range. We'll repeatedly guess a potential value and then quickly determine how many elements in the matrix are smaller than or equal to that guess, adjusting our guess until we pinpoint the Kth smallest.

Here's how the algorithm would work step-by-step:

  1. Start by considering the smallest and largest elements in the entire matrix as the initial possible range where the Kth smallest element must lie.
  2. Pick a 'middle' value within this range as our guess for the Kth smallest element.
  3. Count how many elements in the matrix are less than or equal to our guessed value. Since the rows and columns are sorted, we can do this in one pass starting from the bottom-left corner: if the current element is less than or equal to the guess, every element above it in that column is too, so add them all to the count and move right; otherwise move up.
  4. If the count is less than K, our guess was too small. Move the lower end of the range to one more than our guess.
  5. If the count is greater than or equal to K, it means our guess was too large or possibly correct. Adjust the upper end of our range to be equal to our guess.
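  6. Repeat steps 2-5, continually narrowing the range, until the lower and upper ends of the range converge. The final value is the Kth smallest element. With integer elements it is guaranteed to be a value in the matrix, even if intermediate guesses are not, because it is the smallest value whose count reaches K.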
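
Step 3 is what makes each guess cheap to test. The sketch below isolates it as a hypothetical helper, `count_less_equal` (a name introduced here for illustration). It walks from the bottom-left corner, as the implementation in the next section does; the same idea works mirrored from the top-right corner.

```python
def count_less_equal(matrix, target):
    """Count elements <= target in a matrix with sorted rows and columns."""
    rows, cols = len(matrix), len(matrix[0])
    row, col = rows - 1, 0  # start at the bottom-left corner
    count = 0
    while row >= 0 and col < cols:
        if matrix[row][col] <= target:
            # Columns are sorted, so everything above this cell
            # in the same column is also <= target.
            count += row + 1
            col += 1
        else:
            # Current element is too large; move up one row.
            row -= 1
    return count


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(count_less_equal(matrix, 8))   # 2 (the elements 1 and 5)
print(count_less_equal(matrix, 13))  # 8
```

With k = 8, a guess of 8 yields a count of 2 (too small), while a guess of 13 yields a count of 8, so the search would keep 13 inside the range.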

Code Implementation

def kth_smallest(matrix, k):
    rows = len(matrix)
    cols = len(matrix[0])

    matrix_smallest_element = matrix[0][0]
    matrix_largest_element = matrix[rows - 1][cols - 1]

    while matrix_smallest_element < matrix_largest_element:
        element_midpoint = (matrix_smallest_element + matrix_largest_element) // 2

        # Count elements <= midpoint
        count = 0
        row = rows - 1
        col = 0

        while row >= 0 and col < cols:
            if matrix[row][col] <= element_midpoint:
                count += row + 1
                col += 1
            else:
                row -= 1

        # Adjust search range based on count.
        if count < k:
            matrix_smallest_element = element_midpoint + 1

        # count >= k: the answer is <= midpoint, so keep midpoint in range.
        else:
            matrix_largest_element = element_midpoint

    # The range has converged; its single value is the kth smallest element.
    return matrix_smallest_element
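
To see how the range narrows, here is an instrumented sketch of the same search. The function name `kth_smallest_traced` and the shorter variable names `low`, `high`, and `mid` are introduced for illustration only; the logic is intended to mirror the function above, with each iteration recorded as a `(low, high, mid, count)` tuple.

```python
def kth_smallest_traced(matrix, k):
    """Binary search on values; also returns (low, high, mid, count) per step."""
    rows, cols = len(matrix), len(matrix[0])
    low, high = matrix[0][0], matrix[rows - 1][cols - 1]
    steps = []
    while low < high:
        mid = (low + high) // 2
        # Staircase count of elements <= mid, from the bottom-left corner.
        count, row, col = 0, rows - 1, 0
        while row >= 0 and col < cols:
            if matrix[row][col] <= mid:
                count += row + 1
                col += 1
            else:
                row -= 1
        steps.append((low, high, mid, count))
        if count < k:
            low = mid + 1  # fewer than k elements <= mid: answer is larger
        else:
            high = mid     # at least k elements <= mid: answer is <= mid
    return low, steps


answer, steps = kth_smallest_traced([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8)
for step in steps:
    print(step)
# (1, 15, 8, 2)
# (9, 15, 12, 6)
# (13, 15, 14, 8)
# (13, 14, 13, 8)
print(answer)  # 13
```

The guesses 8 and 14 do not appear in the matrix, yet the search still converges on 13, which does.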

Big(O) Analysis

Time Complexity
O(n log(max - min)). The binary search runs while the lower bound is less than the upper bound and halves the value range on each iteration, so it performs O(log(max - min)) iterations, where max and min are the largest and smallest values in the matrix. Each iteration counts the elements less than or equal to the midpoint by walking from one corner of the matrix, moving one column right or one row up at every step, which takes O(n) time for an n x n matrix. The overall time complexity is therefore O(n log(max - min)).
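
As a rough worked bound, assuming integer elements and max > min: each iteration at least halves the width of the range, and each count makes at most n moves right plus n moves up.

```latex
\[
\underbrace{\bigl(\lfloor \log_2(\max - \min) \rfloor + 1\bigr)}_{\text{binary-search iterations}}
\;\times\;
\underbrace{2n}_{\text{moves per count}}
\;=\; O\bigl(n \log(\max - \min)\bigr)
\]
```

For the 3 x 3 example matrix, max - min = 15 - 1 = 14, so the bound is at most 4 iterations of at most 6 moves each.
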
Space Complexity
O(1). The algorithm keeps only the two bounds of the search range, the midpoint, the running count, and the row and column indices used while counting. These variables occupy a constant amount of space regardless of the matrix size, so the auxiliary space complexity is O(1).

Edge Cases

Case | How to Handle
Matrix is null or empty (n=0) | Return a null value or raise an error (for example, ValueError in Python), since there are no elements to process.
k is less than 1 or greater than n*n | Return a null value or raise an error, because k is outside the valid range of positions.
Matrix is a single-element matrix (n=1) | Return the single element; k=1 is the only valid value of k.
Matrix has very large dimensions (n is large), impacting memory/time complexity | Prefer the binary search approach: its O(n log(max - min)) time and O(1) extra space scale better than heap-based solutions when k is large.
All elements in the matrix are the same | The search range starts with equal bounds, so the loop is skipped and that value is returned.
Matrix contains negative numbers, zeros, and positive numbers | The solution handles all of these, since it relies only on the rows and columns being sorted, not on the sign of the values.
k is 1 (smallest element) | The result is the top-left element of the matrix, which is the smallest.
k is n*n (largest element) | The result is the bottom-right element of the matrix, which is the largest.
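
The first two rows of the table leave open whether to return a null value or raise an error. The sketch below assumes raising, using Python's built-in `ValueError`; the function name `validate_kth_smallest_input` is hypothetical and would be called before running either solution.

```python
def validate_kth_smallest_input(matrix, k):
    """Raise ValueError for an empty matrix or an out-of-range k."""
    # Covers None, [] and [[]].
    if not matrix or not matrix[0]:
        raise ValueError("matrix must be non-empty")
    total = len(matrix) * len(matrix[0])
    if not 1 <= k <= total:
        raise ValueError(f"k must be between 1 and {total}, got {k}")


# A valid input passes silently.
validate_kth_smallest_input([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8)

try:
    validate_kth_smallest_input([[-5]], 2)
except ValueError as error:
    print(error)  # k must be between 1 and 1, got 2
```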