Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix. Note that it is the kth smallest element in the sorted order, not the kth distinct element. You must find a solution with a memory complexity better than O(n^2).

Example: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8. Expected output is 13.

Another example: matrix = [[-5]], k = 1. Expected output is -5.
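Flattening and sorting the first example shows the difference between "kth smallest" and "kth distinct". This is only an illustration: it stores all n^2 values, so it does not meet the memory requirement. The variable names are illustrative.

```python
# Example 1 from the problem statement.
matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
k = 8

# Flatten every row into one list and sort it.
flattened = sorted(value for row in matrix for value in row)
print(flattened)         # [1, 5, 9, 10, 11, 12, 13, 13, 15]
print(flattened[k - 1])  # 13: the duplicate 13 counts toward k

# For contrast, the kth *distinct* value would be a different answer.
distinct = sorted(set(flattened))
print(distinct[k - 1])   # 15
```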
When you get asked this question in a real-life environment, it will often be ambiguous (especially at FAANG). Make sure to ask these questions in that case:
We need to find the kth smallest number among all the values in the matrix. The brute-force method is the most straightforward: collect every number into one list, sort it, and pick the kth smallest. It runs in O(n^2 log n) time and stores all n^2 values, so it does not satisfy the memory requirement.
Here's how the algorithm would work step-by-step:
```python
def kth_smallest_element_brute_force(matrix, k_value):
    all_numbers = []

    # Extract all numbers from the matrix.
    for row in matrix:
        for element in row:
            all_numbers.append(element)

    # Sort all the numbers in ascending order.
    # This step is essential to find the kth smallest element.
    all_numbers.sort()

    # Return the element at the (k_value - 1)th index
    # since lists are zero indexed.
    return all_numbers[k_value - 1]
```
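You can keep the brute-force idea and still meet the memory requirement by using the fact that every row is already sorted. This sketch swaps in a heap-based technique: Python's heapq.merge lazily merges sorted iterables, and CPython's implementation keeps only one pending value per row in a small heap. That gives O(n) extra memory and O(n + k log n) time. The name kth_smallest_merge is hypothetical, and the sketch assumes 1 <= k <= n*n (a larger k raises StopIteration).

```python
import heapq
from itertools import islice


def kth_smallest_merge(matrix, k):
    """Return the kth smallest value by lazily merging the sorted rows.

    heapq.merge walks the rows as iterators, so the full n*n list is never
    built; only one pending value per row is buffered at a time.
    """
    merged = heapq.merge(*matrix)
    # Skip the first k - 1 values in sorted order and return the next one.
    return next(islice(merged, k - 1, None))


print(kth_smallest_merge([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
print(kth_smallest_merge([[-5]], 1))                                   # -5
```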
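To find the kth smallest element efficiently, we binary search over values rather than positions. This works much like searching a phone book, where you keep narrowing down the possible range. The smallest possible answer is the top-left element and the largest is the bottom-right element. We repeatedly guess a midpoint value and count how many elements in the matrix are less than or equal to that guess.

Because the rows and columns are sorted, this count takes only O(n) time. Start at the bottom-left corner. If the current element is <= the guess, every element above it in that column is too, so add row + 1 and move right. Otherwise, move up. If fewer than k elements are <= the guess, the answer must be larger. Otherwise, the answer is at most the guess. The whole search takes O(n log(max - min)) time and O(1) extra memory.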
Here's how the algorithm would work step-by-step:
```python
def kth_smallest(matrix, k):
    rows = len(matrix)
    cols = len(matrix[0])

    # The answer lies between the smallest (top-left) and largest (bottom-right) values.
    low = matrix[0][0]
    high = matrix[rows - 1][cols - 1]

    while low < high:
        # Floor division keeps mid strictly below high, even for negative values.
        mid = (low + high) // 2

        # Count elements <= mid with a staircase walk from the bottom-left corner.
        count = 0
        row = rows - 1
        col = 0
        while row >= 0 and col < cols:
            if matrix[row][col] <= mid:
                # Every element above matrix[row][col] in this column is also <= mid.
                count += row + 1
                col += 1
            else:
                row -= 1

        # Adjust the search range based on the count.
        if count < k:
            # Fewer than k elements are <= mid, so the answer is larger than mid.
            low = mid + 1
        else:
            # At least k elements are <= mid, so the answer is at most mid.
            high = mid

    # low is the smallest value with at least k elements <= it, which is the
    # kth smallest element (and is always a value present in the matrix).
    return low
```
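To see how the range narrows, this sketch reruns the same binary search on the first example and records each iteration as (low, high, mid, count). The helper name kth_smallest_traced is hypothetical. It assumes a non-empty n x n matrix.

```python
def kth_smallest_traced(matrix, k):
    """Binary search over values, recording (low, high, mid, count) per step."""
    n = len(matrix)
    low, high = matrix[0][0], matrix[n - 1][n - 1]
    trace = []
    while low < high:
        mid = (low + high) // 2
        # Staircase count of elements <= mid, starting at the bottom-left corner.
        count, row, col = 0, n - 1, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= mid:
                count += row + 1
                col += 1
            else:
                row -= 1
        trace.append((low, high, mid, count))
        if count < k:
            low = mid + 1
        else:
            high = mid
    return low, trace


answer, trace = kth_smallest_traced([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8)
for low, high, mid, count in trace:
    print(f"low={low:2} high={high:2} mid={mid:2} count<=mid={count}")
print("answer:", answer)
```

On this input the guesses are 8, 12, 14 and 13, with counts 2, 6, 8 and 8. The search therefore settles on 13.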
Case | How to Handle |
---|---|
Matrix is null or empty (n=0) | Return None (null) or raise an error such as ValueError (IllegalArgumentException in Java), since there are no elements to process. |
k is less than 1 or greater than n*n | Return None (null) or raise an error such as ValueError (IllegalArgumentException in Java), because k is outside the valid range of elements. |
Matrix is a single element matrix (n=1) | Return the single element in the matrix, as it will always be the kth smallest for k=1. |
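Matrix has very large dimensions (n is large), impacting memory/time complexity | Prefer the binary search approach. It uses O(1) extra memory and O(n log(max - min)) time, compared with O(n^2) memory for brute force. A min-heap approach uses O(n) memory and O(n + k log n) time, which can be competitive when k is small. |
All elements in the matrix are the same | The top-left and bottom-right values are equal, so the binary search loop never runs and that value is returned, which is correct for every valid k. |
Matrix contains negative numbers, zeros, and positive numbers | The counting logic depends only on the sort order, so any values work. In Python, floor division (//) rounds toward negative infinity, so the midpoint always stays below the upper bound and the loop terminates. In languages where integer division truncates toward zero, compute the midpoint as low + (high - low) / 2 instead. |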
k is 1 (smallest element) | Return the top-left element of the matrix, which is the smallest. |
k is n*n (largest element) | Return the bottom-right element of the matrix, which is the largest. |
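The edge cases in the table can be checked with a thin validation layer around the binary search. This sketch assumes ValueError as the error type for invalid input. The name kth_smallest_checked is hypothetical.

```python
def kth_smallest_checked(matrix, k):
    """Binary-search kth smallest with input validation for the edge cases above."""
    if not matrix or not matrix[0]:
        raise ValueError("matrix must be non-empty")
    n = len(matrix)
    if not 1 <= k <= n * n:
        raise ValueError(f"k must be in [1, {n * n}], got {k}")

    low, high = matrix[0][0], matrix[n - 1][n - 1]
    while low < high:
        # // floors toward negative infinity, so mid < high even for negatives.
        mid = (low + high) // 2
        count, row, col = 0, n - 1, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= mid:
                count += row + 1
                col += 1
            else:
                row -= 1
        if count < k:
            low = mid + 1
        else:
            high = mid
    return low


example = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_checked([[-5]], 1))              # -5 (single element)
print(kth_smallest_checked([[7, 7], [7, 7]], 3))    # 7 (all elements equal)
print(kth_smallest_checked([[-3, 0], [-1, 4]], 2))  # -1 (negatives and zero)
print(kth_smallest_checked(example, 1))             # 1 (k = 1, top-left)
print(kth_smallest_checked(example, 9))             # 15 (k = n*n, bottom-right)

for bad_matrix, bad_k in [([], 1), ([[1, 2], [3, 4]], 5)]:
    try:
        kth_smallest_checked(bad_matrix, bad_k)
    except ValueError as error:
        print("ValueError:", error)
```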