Given an `n x n` `matrix` where each of the rows and columns is sorted in ascending order, return the k<sup>th</sup> smallest element in the matrix. Note that it is the k<sup>th</sup> smallest element in the sorted order, not the k<sup>th</sup> distinct element. You must find a solution with a memory complexity better than O(n<sup>2</sup>).

For example, given `matrix = [[1,5,9],[10,11,13],[12,13,15]]` and `k = 8`, the expected output is `13` because the sorted elements are `[1, 5, 9, 10, 11, 12, 13, 13, 15]` and the 8th smallest is 13. As another example, given `matrix = [[-5]]` and `k = 1`, the expected output is `-5`.
## Kth Smallest Element in a Sorted Matrix
This problem asks us to find the kth smallest element in an n x n matrix where each row and column is sorted in ascending order. The solution should have a memory complexity better than O(n^2).
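The whole approach depends on the input property that every row and every column is in ascending order, with duplicates allowed. The sketch below shows one way to check that property on the examples. `is_row_and_column_sorted` is a hypothetical helper and is not part of the original solution.

```python
def is_row_and_column_sorted(matrix):
    """Hypothetical helper: True if matrix is square and every row and every
    column is non-decreasing (ascending order, duplicates allowed)."""
    n = len(matrix)
    if any(len(row) != n for row in matrix):
        return False  # not an n x n matrix
    rows_ok = all(matrix[i][j] <= matrix[i][j + 1]
                  for i in range(n) for j in range(n - 1))
    cols_ok = all(matrix[i][j] <= matrix[i + 1][j]
                  for i in range(n - 1) for j in range(n))
    return rows_ok and cols_ok


example = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(is_row_and_column_sorted(example))           # True
print(is_row_and_column_sorted([[-5]]))            # True
print(is_row_and_column_sorted([[1, 2], [0, 3]]))  # False: column 0 goes 1 -> 0
```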
### 1. Brute Force Solution
The most straightforward approach is to copy every element of the matrix into a single list, sort it, and return the element at index k-1. This gives the correct result, but the list holds all n^2 elements, so it does not meet the memory complexity requirement.
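```python
def kth_smallest_brute_force(matrix, k):
    # Copy all n*n elements into one list: O(n^2) extra memory.
    flattened = []
    for row in matrix:
        flattened.extend(row)
    # Sort ascending; duplicates are kept, so index k-1 holds the kth smallest.
    flattened.sort()
    return flattened[k - 1]
```
### 2. Binary Search Solution
We can binary search over the range of values instead of over positions. The smallest element is `matrix[0][0]` and the largest is `matrix[n-1][n-1]`, so the answer lies between them. At each step we take the middle value `mid` of the current range and count how many elements in the matrix are less than or equal to `mid`. If the count is less than k, the kth smallest element must be larger than `mid`, so we search the higher half. Otherwise the answer is `mid` or smaller, so we search the lower half. Because the columns are sorted, the count can be taken with a staircase walk. Start at the last column of the first row and move left while the entry is greater than `mid`, then carry the column index `j` down to the next row. An entry that exceeds `mid` in one row also exceeds it in every row below. This version assumes integer entries, since it uses `//` and `mid + 1`.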
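```python
def kth_smallest(matrix, k):
    n = len(matrix)
    # The answer lies between the smallest and largest values in the matrix.
    low = matrix[0][0]
    high = matrix[n - 1][n - 1]
    while low < high:
        mid = low + (high - low) // 2
        # Count elements <= mid with a staircase walk from the top-right corner.
        # Columns are sorted, so j never has to move right again.
        count = 0
        j = n - 1
        for i in range(n):
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
            count += j + 1  # matrix[i][0..j] are all <= mid
        if count < k:
            low = mid + 1  # fewer than k elements <= mid: answer is larger
        else:
            high = mid  # at least k elements <= mid: answer is mid or smaller
    return low
```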
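To see how the value range shrinks, the sketch below reruns the same search on the first example and records `(low, high, mid, count)` on every iteration. `kth_smallest_traced` is a hypothetical instrumented copy of `kth_smallest`, added here only for illustration.

```python
def kth_smallest_traced(matrix, k):
    """Hypothetical instrumented copy of kth_smallest: also returns a list of
    (low, high, mid, count) tuples, one per binary-search iteration."""
    n = len(matrix)
    low, high = matrix[0][0], matrix[n - 1][n - 1]
    trace = []
    while low < high:
        mid = low + (high - low) // 2
        count, j = 0, n - 1
        for i in range(n):
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
            count += j + 1
        trace.append((low, high, mid, count))
        if count < k:
            low = mid + 1
        else:
            high = mid
    return low, trace


result, trace = kth_smallest_traced([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8)
for step in trace:
    print(step)
# (1, 15, 8, 2)
# (9, 15, 12, 6)
# (13, 15, 14, 8)
# (13, 14, 13, 8)
print(result)  # 13
```

The range goes from 15 candidate values to one in four iterations. Note that `mid = 14` is not in the matrix, yet the search still settles on 13, which is.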
The time complexity of the binary search solution is O(n log(range)). Here n is the number of rows (or columns) in the matrix, and range is the difference between the largest and smallest elements, `matrix[n-1][n-1] - matrix[0][0]`. Each iteration halves the value range, so the binary search runs about log(range) times. Each count takes O(n) time. The outer loop visits each of the n rows once, and the column index `j` moves left at most n times in total because it is never reset between rows.
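To check the O(n)-per-count claim, the sketch below pulls the staircase walk out into a hypothetical `count_less_equal` helper that also reports how many loop steps it took. It then compares the result against a naive count over all n^2 entries on random matrices. `random_sorted_matrix` is likewise a hypothetical test helper that builds a matrix with non-decreasing rows and columns. Neither function is part of the original solution.

```python
import random


def count_less_equal(matrix, x):
    """Count entries <= x with the staircase walk used in kth_smallest.

    Returns (count, steps); steps counts one step per row plus one per
    leftward move of j, to show the work is O(n) rather than O(n^2).
    """
    n = len(matrix)
    j = n - 1
    count = steps = 0
    for i in range(n):
        steps += 1  # one step per row
        while j >= 0 and matrix[i][j] > x:
            j -= 1  # j only moves left, at most n times in total
            steps += 1
        count += j + 1
    return count, steps


def random_sorted_matrix(n, rng):
    """Hypothetical test helper: each cell is the larger of its upper and left
    neighbours plus a small non-negative step, so rows and columns are sorted."""
    m = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            up = m[i - 1][j] if i > 0 else 0
            left = m[i][j - 1] if j > 0 else 0
            m[i][j] = max(up, left) + rng.randint(0, 3)
    return m


rng = random.Random(0)
for _ in range(200):
    n = rng.randint(1, 8)
    m = random_sorted_matrix(n, rng)
    x = rng.randint(-1, m[-1][-1] + 1)
    count, steps = count_less_equal(m, x)
    assert count == sum(v <= x for row in m for v in row)  # matches naive count
    assert steps <= 2 * n  # n row steps plus at most n column moves
```

The 2n bound holds because `j` starts at n - 1 and can reach -1 at most once, so there are at most n leftward moves across the whole walk.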
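The brute force solution has a time complexity of O(n^2 log(n^2)), which simplifies to O(n^2 log n) because log(n^2) = 2 log n. Flattening the n^2 elements takes O(n^2) time. Sorting them with a comparison sort, such as Python's built-in `list.sort` or merge sort, takes O(n^2 log(n^2)) time and dominates.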
The space complexity of the binary search solution is O(1) because we are only using a few variables to store the low, high, and mid values, and the count. We do not allocate any auxiliary data structures that scale with the size of the input.
The brute force method has a space complexity of O(n^2) because we are creating a new array with n^2 elements to store the flattened matrix.
Here's how the `kth_smallest` function handles edge cases:

The loop `while low < high` exits when `low` equals `high`, that is, when the search range has narrowed to a single value. That value is the smallest integer v with at least k elements less than or equal to v, and it must occur in the matrix. If it did not, v - 1 would have exactly the same count, which contradicts the choice of v as the smallest. This also covers the single-element matrix: `low` and `high` are both initialized to that element, so the loop body never runs and the element is returned directly.
The algorithm handles duplicate elements correctly because it counts the number of elements less than or equal to `mid`. It does not make any assumptions about the uniqueness of matrix elements.
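As a sanity check on the termination and duplicate behaviour, the sketch below runs the search on a matrix with repeated values. For each k, the result v should be the smallest integer with at least k elements less than or equal to it, and v should actually appear in the matrix. `count_less_equal` (the counting loop factored out) and `satisfies_invariant` are hypothetical helpers for this illustration. Like the original function, the sketch assumes integer entries.

```python
def count_less_equal(matrix, x):
    # Staircase count of entries <= x; j only ever moves left.
    n = len(matrix)
    count, j = 0, n - 1
    for i in range(n):
        while j >= 0 and matrix[i][j] > x:
            j -= 1
        count += j + 1
    return count


def kth_smallest(matrix, k):
    # Same value-range binary search as kth_smallest, with the count factored out.
    low, high = matrix[0][0], matrix[-1][-1]
    while low < high:
        mid = low + (high - low) // 2
        if count_less_equal(matrix, mid) < k:
            low = mid + 1
        else:
            high = mid
    return low


def satisfies_invariant(matrix, k):
    """Hypothetical check: v = kth_smallest(matrix, k) is the smallest integer
    with at least k entries <= v, and v occurs in the matrix."""
    v = kth_smallest(matrix, k)
    return (
        count_less_equal(matrix, v) >= k
        and count_less_equal(matrix, v - 1) < k
        and any(v in row for row in matrix)
    )


dupes = [[1, 1, 3], [1, 2, 3], [2, 3, 3]]
print([kth_smallest(dupes, k) for k in range(1, 10)])  # [1, 1, 1, 2, 2, 3, 3, 3, 3]
print(all(satisfies_invariant(dupes, k) for k in range(1, 10)))  # True
print(kth_smallest([[-5]], 1))  # -5: low == high, so the loop body never runs
```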
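If `k` is out of range (less than 1 or greater than `n*n`), the function still returns a value, but a misleading one. For `k < 1` every count is at least `k`, so the search collapses to `matrix[0][0]`. For `k > n*n` every count is below `k`, so it collapses to `matrix[n-1][n-1]`. An empty matrix raises an `IndexError` at `matrix[0][0]`. It is better to validate the input before the binary search: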
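The sketch below shows what the unvalidated search returns for out-of-range `k` on the first example. It is a standalone copy of the same binary search, so the block runs on its own.

```python
def kth_smallest(matrix, k):
    # Unvalidated value-range binary search, same logic as kth_smallest.
    n = len(matrix)
    low, high = matrix[0][0], matrix[n - 1][n - 1]
    while low < high:
        mid = low + (high - low) // 2
        count, j = 0, n - 1
        for i in range(n):
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
            count += j + 1
        if count < k:
            low = mid + 1  # when k > n*n, count <= n*n < k on every iteration
        else:
            high = mid  # when k <= 0, count >= 0 >= k on every iteration
    return low


m = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest(m, 0))    # 1: collapses to matrix[0][0]
print(kth_smallest(m, 100))  # 15: collapses to matrix[n-1][n-1]
```

Neither answer signals that something went wrong, which is why the check belongs before the search.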
```python
def kth_smallest_with_edge_cases(matrix, k):
    n = len(matrix)
    # Reject inputs for which the search would crash or return a meaningless value.
    if n == 0:
        raise ValueError("Matrix cannot be empty")
    if k < 1 or k > n * n:
        raise ValueError("k is out of range")
    low = matrix[0][0]
    high = matrix[n - 1][n - 1]
    while low < high:
        mid = low + (high - low) // 2
        # Staircase count of elements <= mid, O(n) per iteration.
        count = 0
        j = n - 1
        for i in range(n):
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
            count += j + 1
        if count < k:
            low = mid + 1
        else:
            high = mid
    return low
```
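One way to gain confidence in the validated version is to compare it against the sort-based brute force for every valid `k` on random inputs, and to confirm that bad inputs raise `ValueError`. The sketch below does both. `random_sorted_matrix` and `cross_check` are hypothetical test helpers, and the function is copied so the block is self-contained.

```python
import random


def kth_smallest_with_edge_cases(matrix, k):
    # Copy of the validated version so this block runs on its own.
    n = len(matrix)
    if n == 0:
        raise ValueError("Matrix cannot be empty")
    if k < 1 or k > n * n:
        raise ValueError("k is out of range")
    low, high = matrix[0][0], matrix[n - 1][n - 1]
    while low < high:
        mid = low + (high - low) // 2
        count, j = 0, n - 1
        for i in range(n):
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
            count += j + 1
        if count < k:
            low = mid + 1
        else:
            high = mid
    return low


def random_sorted_matrix(n, rng):
    # Hypothetical helper: sorted rows and columns, values may be negative.
    m = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            up = m[i - 1][j] if i > 0 else -10
            left = m[i][j - 1] if j > 0 else -10
            m[i][j] = max(up, left) + rng.randint(0, 3)
    return m


def cross_check(trials=300, seed=1):
    """Hypothetical harness: compare with sorting for every valid k."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 6)
        m = random_sorted_matrix(n, rng)
        flat = sorted(x for row in m for x in row)
        for k in range(1, n * n + 1):
            assert kth_smallest_with_edge_cases(m, k) == flat[k - 1]
    return True


print(cross_check())  # True

for bad_args in (([], 1), ([[1]], 0), ([[1]], 2)):
    try:
        kth_smallest_with_edge_cases(*bad_args)
    except ValueError as exc:
        print(exc)  # "Matrix cannot be empty", then "k is out of range" twice
```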