题目
给定一个n行n列的矩阵,这个矩阵的每一行是递增有序的,求这个矩阵中第k小的元素。
解答
优解
基于二分查找和按行统计小于等于目标值的元素个数。算法的时间复杂度为,其中D是矩阵中元素值域的范围(即最大值与最小值的差),空间复杂度为
(不包括输入矩阵)。
算法描述
确定值域范围:
计算矩阵中的最小值
min_val
:取所有行首元素(即每行的第一个元素)的最小值,即min_val = min(matrix[i][0] for i in range(n))
。计算矩阵中的最大值
max_val
:取所有行尾元素(即每行的最后一个元素)的最大值,即max_val = max(matrix[i][n-1] for i in range(n))
。初始化二分查找的左边界
left = min_val
,右边界right = max_val
。
二分查找第 k 小元素:
当
left < right
时,执行循环:计算中间值
mid = (left + right) // 2
(整数除法)。统计矩阵中小于等于
mid
的元素个数count
:对于每一行 i(因为行是递增有序),使用二分查找找到该行中最后一个小于等于
mid
的元素的列索引 j(即最大的 j 满足matrix[i][j] <= mid
)。则该行中小于等于mid
的元素个数为 j+1(列索引从 0 开始)。对所有行的结果求和,得到
count
。
如果
count >= k
,则第 k 小元素小于等于mid
,设置right = mid
。否则,第 k 小元素大于
mid
,设置left = mid + 1
。
循环结束后,
left
即为第 k 小元素的值。
算法正确性
该算法通过二分查找值域,逐步缩小第 k 小元素可能存在的范围。
在每次迭代中,计算
count(mid)
(小于等于mid
的元素个数)与 k 比较,可以确定第 k 小元素位于左半区间还是右半区间。最终,
left
会收敛到矩阵中的一个实际元素,且满足是第 k 小元素(详见示例验证)。
时间复杂度分析
确定
min_val
和max_val
需要时间。
二分查找的迭代次数为
,其中D=max_val−min_val。
在每次迭代中,统计
count(mid)
需要对每一行进行一次二分查找(每行时间复杂度),因此每次迭代的时间复杂度为
。
总时间复杂度为
。
示例说明
考虑一个 2×2 矩阵:
元素集合为 {1,2,3,5},排序后为 [1,2,3,5]。
求第 k=2 小元素:
min_val = min(1, 2) = 1
,max_val = max(3, 5) = 5
,left = 1
,right = 5
。第一次迭代:
mid = (1 + 5) // 2 = 3
,count(<=3)
:第一行:元素 [1,3],1≤3,3≤3,个数为 2。
第二行:元素 [2,5],2≤3,5>3,个数为 1。
count = 2 + 1 = 3 >= 2
,设置right = 3
。
第二次迭代:
left = 1
,right = 3
,mid = (1 + 3) // 2 = 2
,count(<=2)
:第一行:1≤2,3>2,个数为 1。
第二行:2≤2,5>2,个数为 1。
count = 1 + 1 = 2 >= 2
,设置right = 2
。
第三次迭代:
left = 1
,right = 2
,mid = (1 + 2) // 2 = 1
,count(<=1)
:第一行:1≤1,3>1,个数为 1。
第二行:2>1,个数为 0。
count = 1 < 2
,设置left = 1 + 1 = 2
。
循环结束,
left = 2
,返回 2(正确,第 2 小元素是 2)。
代码实现
def kthSmallest(matrix, k):
n = len(matrix)
# 初始化二分查找的边界
left = matrix[0][0] # 最小值:矩阵左上角
right = matrix[n-1][n-1] # 最大值:矩阵右下角
# 二分查找
while left < right:
mid = (left + right) // 2
count = 0 # 统计小于等于mid的元素个数
col = n - 1 # 从每行的末尾开始检查
# 遍历每一行
for i in range(n):
# 在当前行中,从右向左找到第一个小于等于mid的元素
while col >= 0 and matrix[i][col] > mid:
col -= 1
count += (col + 1) # 该行中小于等于mid的元素个数
# 调整二分查找边界
if count >= k:
right = mid
else:
left = mid + 1
return left
# 示例测试
if __name__ == "__main__":
matrix1 = [[1, 3], [2, 5]]
print(kthSmallest(matrix1, 2)) # 输出: 2
matrix2 = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kthSmallest(matrix2, 4)) # 输出: 10
matrix3 = [[-5]]
print(kthSmallest(matrix3, 1)) # 输出: -5
暴力解
忽略行递增条件,直接全排序:
def kth_smallest(matrix, k):
n = len(matrix)
# 提取所有元素到一维列表
flat_list = []
for i in range(n):
for j in range(n):
flat_list.append(matrix[i][j])
# 排序列表
flat_list.sort()
# 返回第k小的元素(k从1开始,索引为k-1)
return flat_list[k-1]