Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the k-th smallest element in the matrix. Note that it is the k-th smallest element in the sorted order, not the k-th distinct element. You must find a solution with a memory complexity better than O(n^2).

Example: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8. Expected output: 13.
Another example: matrix = [[-5]], k = 1. Expected output: -5.
Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13
Explanation: The elements of the matrix in sorted order are [1,5,9,10,11,12,13,13,15], and the 8th smallest number is 13.
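To make the "sorted order, not distinct" distinction concrete, here is a small sketch on the example above. It flattens the matrix purely for illustration (this is the brute-force idea, not the memory-efficient approach), and the variable names are illustrative assumptions rather than part of the original.

```python
# Example from the problem statement.
matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
k = 8

# Flatten and sort, keeping duplicates: this is the "sorted order" the problem means.
all_values = sorted(v for row in matrix for v in row)
print(all_values)            # [1, 5, 9, 10, 11, 12, 13, 13, 15]
print(all_values[k - 1])     # 13  (k-th smallest, duplicates counted)

# For contrast only: the k-th *distinct* value, which is NOT what is asked.
distinct_values = sorted(set(all_values))
print(distinct_values)           # [1, 5, 9, 10, 11, 12, 13, 15]
print(distinct_values[k - 1])    # 15
```

Because 13 appears twice, it occupies both the 7th and 8th positions in sorted order, which is why the answer is 13 rather than 15.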
The simplest approach is to extract all elements from the matrix into a single array, sort the array, and then return the element at index k-1. This solution is straightforward, but it does not meet the memory requirement, because it stores all n^2 elements at once.
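```python
def kth_smallest_brute_force(matrix, k):
    # Collect all n^2 elements into one flat list.
    arr = []
    for row in matrix:
        arr.extend(row)
    # Sort ascending; duplicates are kept, so index k - 1 is the k-th smallest.
    arr.sort()
    return arr[k - 1]
```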
Time complexity: O(n^2 log(n^2)), which simplifies to O(n^2 log n), because we sort an array of n^2 elements.
Space complexity: O(n^2) to store all the elements in a single array.
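The same brute force can be written more compactly with `itertools.chain.from_iterable`, which flattens the rows lazily before `sorted` materializes them. This is a sketch of an equivalent formulation, not a different algorithm: `sorted` still builds a list of all n^2 values, so the time and space costs above are unchanged. It is checked here against both examples from the problem statement.

```python
from itertools import chain


def kth_smallest_brute_force(matrix, k):
    # Flatten row by row, sort all n^2 values, and index: O(n^2 log n) time, O(n^2) space.
    return sorted(chain.from_iterable(matrix))[k - 1]


print(kth_smallest_brute_force([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
print(kth_smallest_brute_force([[-5]], 1))  # -5
```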
Since the matrix is sorted in both rows and columns, we can binary search over the range of values (rather than over indices) to find the k-th smallest element. The idea is to pick a middle value, count how many elements in the matrix are less than or equal to it, and adjust the search range based on that count.
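The counting step is what keeps this approach within the memory limit. Below is a minimal sketch of it as a standalone helper; the name `count_less_equal` is hypothetical. It walks a "staircase": starting at the last column, it moves the column pointer `j` left while the current entry exceeds `x`. Because columns are sorted, any column whose entry is too large in row i is also too large in every later row, so `j` never has to move back to the right.

```python
def count_less_equal(matrix, x):
    """Count entries <= x in an n x n matrix whose rows and columns are sorted ascending."""
    n = len(matrix)
    count = 0
    j = n - 1  # column pointer; it only ever moves left
    for i in range(n):
        # Skip the entries at the end of row i that are larger than x.
        while j >= 0 and matrix[i][j] > x:
            j -= 1
        # Columns 0..j of row i are all <= x.
        count += j + 1
    return count


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(count_less_equal(matrix, 12))  # 6
print(count_less_equal(matrix, 13))  # 8
```

With k = 8, the count at 12 is 6 (fewer than k) and the count at 13 is 8 (at least k), so 13 is the smallest value whose count reaches k, which is the value the binary search converges on.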
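1. Initialize left to the smallest element in the matrix (matrix[0][0]) and right to the largest element (matrix[n-1][n-1]).
2. While left < right:
   - Compute mid = left + (right - left) // 2.
   - Count the number of elements less than or equal to mid.
   - If the count is less than k, the k-th smallest element is greater than mid, so update left = mid + 1.
   - Otherwise, the k-th smallest element is less than or equal to mid, so update right = mid.
3. When the loop ends, left equals right; return left (or right), which is the k-th smallest element.

```python
def kth_smallest_optimal(matrix, k):
    n = len(matrix)
    # The answer lies between the smallest and the largest element.
    left = matrix[0][0]
    right = matrix[n - 1][n - 1]
    while left < right:
        mid = left + (right - left) // 2
        # Count elements <= mid with a staircase walk: start at the last
        # column and move left in each row while the value exceeds mid.
        # Because columns are sorted, j never has to move back to the right.
        count = 0
        j = n - 1
        for i in range(n):
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
            count += j + 1  # columns 0..j of row i are <= mid
        if count < k:
            left = mid + 1  # fewer than k elements <= mid: the answer is larger
        else:
            right = mid     # at least k elements <= mid: the answer is mid or smaller
    return left
```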
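To see how the search range shrinks, the sketch below reruns the same algorithm on the first example and records (left, right, mid, count) at each iteration. `kth_smallest_traced` is a hypothetical instrumented copy, not part of the original solution, and it assumes integer entries, as the `//` midpoint already does.

```python
def kth_smallest_traced(matrix, k):
    """Same value-range binary search, but records every iteration (hypothetical helper)."""
    n = len(matrix)
    left, right = matrix[0][0], matrix[n - 1][n - 1]
    trace = []
    while left < right:
        mid = left + (right - left) // 2
        # Staircase count of entries <= mid.
        count, j = 0, n - 1
        for i in range(n):
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
            count += j + 1
        trace.append((left, right, mid, count))
        if count < k:
            left = mid + 1  # answer is above mid
        else:
            right = mid     # answer is mid or below
    return left, trace


answer, trace = kth_smallest_traced([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8)
for lo, hi, mid, count in trace:
    print(f"left={lo:2} right={hi:2} mid={mid:2} count={count}")
print("answer:", answer)
```

This prints four iterations, (left, right, mid, count) = (1, 15, 8, 2), (9, 15, 12, 6), (13, 15, 14, 8), (13, 14, 13, 8), and then the answer 13. The value mid = 14 is not an element of the matrix, which is harmless: for integer entries the loop stops at the smallest value v with at least k elements <= v, and since fewer than k elements are <= v - 1, v itself must occur in the matrix.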
Time complexity: O(n log(X)), where n is the dimension of the matrix and X = matrix[n-1][n-1] - matrix[0][0] is the range of possible values. The binary search performs O(log(X)) iterations. Each count takes O(n) time: the loop visits n rows, and the column pointer j moves left at most n times in total across all rows, so one count costs at most about 2n steps.
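The bound can be checked empirically. This sketch instruments the same search to record the number of binary-search iterations and the largest number of steps (row visits plus pointer moves) spent in a single count. The test matrix `matrix[i][j] = i + j` is an assumption chosen only because its rows and columns are sorted, and `kth_smallest_instrumented` is a hypothetical helper name.

```python
def kth_smallest_instrumented(matrix, k):
    """Value-range binary search that also counts iterations and staircase steps."""
    n = len(matrix)
    left, right = matrix[0][0], matrix[n - 1][n - 1]
    iterations = 0
    max_steps = 0  # most row visits + pointer moves seen in a single count
    while left < right:
        iterations += 1
        mid = left + (right - left) // 2
        count, j, steps = 0, n - 1, 0
        for i in range(n):
            steps += 1  # one visit per row
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
                steps += 1  # j moves left at most n times in total
            count += j + 1
        max_steps = max(max_steps, steps)
        if count < k:
            left = mid + 1
        else:
            right = mid
    return left, iterations, max_steps


# Assumed test input: matrix[i][j] = i + j has sorted rows and columns.
n = 100
matrix = [[i + j for j in range(n)] for i in range(n)]
X = matrix[n - 1][n - 1] - matrix[0][0]  # value range, here 198
answer, iterations, max_steps = kth_smallest_instrumented(matrix, k=5000)
print(answer, iterations, max_steps)
print("iteration bound:", X.bit_length(), "step bound:", 2 * n)
```

Each iteration shrinks an interval of X + 1 candidate integers to at most half its size (rounded up), so the loop runs at most ceil(log2(X + 1)) times, which `X.bit_length()` computes exactly for X >= 0; each count costs at most 2n steps.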
Space complexity: O(1), because the algorithm uses only a constant amount of extra space.
Constraints: 1 <= n <= 300, so an empty matrix is not a valid input.
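One way to gain confidence in the optimal solution is a randomized cross-check against the brute force on small inputs that respect this constraint, including n = 1 and matrices with duplicates and negative values. The generator `random_sorted_matrix`, its value ranges, and the test sizes are assumptions for testing only; it builds sorted matrices by making each cell at least as large as its top and left neighbours.

```python
import random


def kth_smallest_brute_force(matrix, k):
    # Reference answer: flatten, sort, index (O(n^2) extra space).
    return sorted(v for row in matrix for v in row)[k - 1]


def kth_smallest_optimal(matrix, k):
    # Binary search on the value range with a staircase count (O(1) extra space).
    n = len(matrix)
    left, right = matrix[0][0], matrix[n - 1][n - 1]
    while left < right:
        mid = left + (right - left) // 2
        count, j = 0, n - 1
        for i in range(n):
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
            count += j + 1
        if count < k:
            left = mid + 1
        else:
            right = mid
    return left


def random_sorted_matrix(n, rng):
    # Hypothetical test helper: each cell is >= its top and left neighbours,
    # so every row and every column is sorted ascending.
    m = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            top = m[i - 1][j] if i > 0 else -20
            left = m[i][j - 1] if j > 0 else -20
            m[i][j] = max(top, left) + rng.randint(0, 3)  # a step of 0 creates duplicates
    return m


rng = random.Random(0)
mismatches = 0
for _ in range(200):
    n = rng.randint(1, 8)  # small sizes, including n = 1
    matrix = random_sorted_matrix(n, rng)
    for k in range(1, n * n + 1):  # every valid k
        if kth_smallest_optimal(matrix, k) != kth_smallest_brute_force(matrix, k):
            mismatches += 1
print("mismatches:", mismatches)
```

Comparing against the brute force for every valid k exercises the duplicate handling directly: whenever a value repeats, several consecutive k must map to that same value.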