文章目录
PyTorch 张量核心操作——比较、排序与数据校验
在深度学习开发中,张量(Tensor)的比较、排序和数据校验是基础且高频的操作。无论是模型训练中的数据预处理,还是推理阶段的结果分析,这些操作都扮演着重要角色。本文将系统讲解 PyTorch 中与张量相关的比较运算、排序方法、Top-K 选取、K-th 值提取以及数据合法性校验,涵盖函数原型、参数详解、代码示例和结果分析,帮助初学者全面掌握这些核心技能。
一、张量比较运算:判断元素间的关系
比较运算用于判断张量元素之间的大小关系或相等性,返回与输入张量形状相同的布尔型张量(True
/False
)。PyTorch 提供了丰富的比较运算符和函数,支持元素级别的逐元素比较。
1. 基础比较运算符
PyTorch 支持与 Python 类似的比较运算符,包括 ==
(等于)、!=
(不等于)、>
(大于)、<
(小于)、>=
(大于等于)、<=
(小于等于)。这些运算符均为元素级运算,即对两个张量的对应元素逐一进行比较。
运算符特点:
- 要求两个张量形状相同或可广播(广播机制见后文补充)。
- 返回布尔型张量,
True
表示满足条件,False
表示不满足。
运算原理:
张量的比较运算遵循 “位置对应” 原则 :两个张量必须形状相同(或可通过广播机制扩展为相同形状),然后对相同位置的元素逐一进行比较,最终生成一个形状相同的布尔张量(True
/False
)。
- “大小” 的含义:张量中元素的大小就是其数值大小(如
3 > 2
、-1 < 0
等),与元素在张量中的位置无关。 - 布尔张量的意义:结果中
True
表示对应位置的元素满足比较条件,False
表示不满足。
运算演示:
示例:比较两个一维张量 a = [1, 3, 5, 7]
和 b = [2, 3, 4, 8]
步骤 1:确认张量形状
a
的形状:(4,)
(1 维,4 个元素)b
的形状:(4,)
(1 维,4 个元素)
形状相同,可直接比较(无需广播)。
步骤 2:逐元素比较(以
a > b
为例)比较逻辑:对每个索引
i
(0 ≤ i < 4),判断a[i] > b[i]
是否成立。i=0:
a[0] = 1
,b[0] = 2
→1 > 2
?→ Falsei=1:
a[1] = 3
,b[1] = 3
→3 > 3
?→ Falsei=2:
a[2] = 5
,b[2] = 4
→5 > 4
?→ Truei=3:
a[3] = 7
,b[3] = 8
→7 > 8
?→ False
步骤 3:生成结果张量
- 将上述判断结果按原位置组合,得到布尔张量:
a > b
的结果为[False, False, True, False]
- 将上述判断结果按原位置组合,得到布尔张量:
代码示例:
import torch
# 定义两个形状相同的张量
a = torch.tensor([1, 3, 5, 7])
b = torch.tensor([2, 3, 4, 8])
# 比较运算
print("a == b:", a == b) # 等于
print("a != b:", a != b) # 不等于
print("a > b: ", a > b) # 大于
print("a < b: ", a < b) # 小于
print("a >= b:", a >= b) # 大于等于
print("a <= b:", a <= b) # 小于等于
运行结果:
a == b: tensor([False, True, False, False])
a != b: tensor([ True, False, True, True])
a > b: tensor([False, False, True, False])
a < b: tensor([ True, False, False, True])
a >= b: tensor([False, True, True, False])
a <= b: tensor([ True, True, False, True])
结果分析:
- 每个运算符都对
a
和b
的对应元素进行比较(如a[0]=1
与b[0]=2
比较,1 < 2
故a < b
的第 0 位为True
)。 - 布尔张量的形状与输入张量一致(均为
(4,)
),便于后续基于条件筛选元素(如a[a > b]
可提取a
中大于b
对应元素的值)。
2. 比较函数:torch.eq()
与 torch.ne()
等
除运算符外,PyTorch 还提供了对应的函数形式,如 torch.eq()
(等于)、torch.ne()
(不等于)、torch.gt()
(大于)、torch.lt()
(小于)、torch.ge()
(大于等于)、torch.le()
(小于等于)。这些函数与运算符功能一致,但支持更灵活的参数设置(如广播)。
函数原型:
torch.eq(input, other, *, out=None) → Tensor
torch.ne(input, other, *, out=None) → Tensor
torch.gt(input, other, *, out=None) → Tensor
# 其余函数参数类似
参数说明:
input
:输入张量(第一个比较对象)。other
:第二个比较对象(可以是张量或标量)。out
(可选):输出张量,用于存储结果(需与预期输出形状一致)。
代码示例(支持广播机制):
# 形状不同但可广播的张量比较
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状 (2, 3)
b = torch.tensor([3]) # 标量张量(形状 ()),可广播为 (2, 3)
# 使用函数进行比较(等价于 a < 3)
lt_result = torch.lt(a, b)
print("a < 3 的结果:\n", lt_result)
运行结果:
a < 3 的结果:
tensor([[ True, True, False],
[False, False, False]])
结果分析:
b
是标量张量,通过广播机制自动扩展为与a
同形状的张量[[3, 3, 3], [3, 3, 3]]
。torch.lt(a, b)
逐元素比较a
和扩展后的b
,返回布尔张量(a
中元素小于 3 的位置为True
)。
比较运算的核心应用:
- 条件筛选:通过
tensor[布尔张量]
提取满足条件的元素(如a[a > 5]
提取a
中大于 5 的元素)。 - 掩码操作:生成掩码张量用于数据过滤或加权计算。
- 结果验证:在模型推理中判断预测结果与标签的匹配情况(如计算准确率时统计
pred == label
的数量)。
3. 张量相等性判断 torch.equal()
在 PyTorch 中,torch.equal()
是一个用于判断两个张量是否完全相等的函数。它与我们之前讲过的元素级比较运算符(如 ==
)不同,后者返回一个布尔张量,而 torch.equal()
会返回一个单一的布尔值(True
或 False
),表示两个张量是否在所有元素和形状上都完全一致。
3.1 torch.equal()
函数原型
torch.equal(input1, input2) → bool
参数说明:
input1
:第一个待比较的张量。input2
:第二个待比较的张量。
返回值:
- 布尔值(
True
或False
):如果两个张量的形状相同且所有对应元素都相等,则返回True
;否则返回False
。
3.2 核心特点:判断“整体相等性”
torch.equal()
的核心是整体判断,而非元素级判断。它有两个严格条件:
- 两个张量的形状必须完全相同;
- 两个张量所有对应位置的元素必须完全相等。
只有同时满足这两个条件,才会返回 True
。
3.3 代码示例与结果分析
示例 1:形状相同且元素全相等
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
print(torch.equal(a, b)) # 输出:True
分析:a
和 b
形状均为 (3,)
,且所有元素对应相等,因此返回 True
。
示例 2:形状相同但元素不全相等
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 4, 3])
print(torch.equal(a, b)) # 输出:False
分析:虽然形状相同,但 a[1]=2
与 b[1]=4
不相等,因此返回 False
。
示例 3:形状不同(即使元素“看起来”对应)
a = torch.tensor([[1, 2], [3, 4]]) # 形状 (2, 2)
b = torch.tensor([1, 2, 3, 4]) # 形状 (4,)
print(torch.equal(a, b)) # 输出:False
分析:a
是 2×2 的二维张量,b
是长度为 4 的一维张量,形状不同,直接返回 False
。
示例 4:浮点数的相等判断(需注意精度)
a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0 + 1e-9, 2.0]) # 第一个元素有微小差异
print(torch.equal(a, b)) # 输出:False
分析:浮点数由于精度问题,即使差异极小(如 1e-9
),torch.equal()
也会判定为不相等。如果需要忽略微小误差,应使用 torch.allclose()
(后续会介绍)。
示例 5:包含 NaN
的张量(特殊情况)
a = torch.tensor([1.0, torch.nan])
b = torch.tensor([1.0, torch.nan])
print(torch.equal(a, b)) # 输出:False
分析:由于 NaN
与任何值(包括自身)都不相等(数学定义),因此即使两个张量都含 NaN
,torch.equal()
也会返回 False
。
3.4 与 ==
运算符的区别
操作 | 返回值类型 | 核心逻辑 | 典型用途 |
---|---|---|---|
torch.equal(a, b) |
单一布尔值(bool ) |
判断整体是否完全相等 | 验证两个张量是否完全一致 |
a == b |
布尔张量(Tensor ) |
元素级比较,返回每个位置的结果 | 筛选特定位置的元素 |
对比示例:
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 4, 3])
print("a == b 的结果:", a == b) # 元素级比较
print("torch.equal(a, b):", torch.equal(a, b)) # 整体比较
输出:
a == b 的结果: tensor([ True, False, True])
torch.equal(a, b): False
3.5 使用场景与注意事项
场景:
- 验证张量是否完全一致:如检查模型参数在训练前后是否发生预期变化,或验证两个计算结果是否完全相同。
- 单元测试:在测试代码中,判断函数输出是否与预期张量完全一致。
- 调试:排查代码中张量是否在传递过程中被意外修改(形状或元素)。
注意:
- 形状优先:即使两个张量的元素“数量相同”但形状不同(如
(2,2)
和(4,)
),也会返回False
。 - 浮点数精度:对浮点数张量,微小的精度误差(如数值计算中的舍入误差)会导致返回
False
,此时应使用torch.allclose()
并设置合理的误差容限(如atol=1e-5
)。 NaN
特殊处理:含NaN
的张量几乎不可能被判定为相等,需单独处理NaN
位置(如先通过torch.isnan()
检测)。
总之,torch.equal()
是判断两个张量“完全一致性”的便捷工具,适合需要严格验证张量是否相同的场景。但使用时需注意形状、浮点数精度和 NaN
等特殊情况。
二、排序操作:torch.sort()
排序是将张量元素按升序或降序重新排列的操作。PyTorch 中 torch.sort()
函数不仅返回排序后的张量,还返回原元素在排序后的位置索引,这对后续分析元素来源至关重要。
函数原型:
torch.sort(input, dim=-1, descending=False, *, out=None) → (Tensor, Tensor)
参数说明:
input
:需要排序的输入张量。dim
(默认-1
):指定排序的维度(如dim=0
按行排序,dim=1
按列排序)。descending
(默认False
):排序方式,False
为升序,True
为降序。out
(可选):元组(sorted_tensor, indices)
,用于存储输出结果。
返回值:
- 元组
(sorted_tensor, indices)
:sorted_tensor
:排序后的张量,形状与输入一致。indices
:整数张量,记录原张量元素在排序后张量中的位置索引。
1. 一维张量排序
代码示例:
x = torch.tensor([3, 1, 4, 2, 5])
# 升序排序(默认)
sorted_x, indices = torch.sort(x)
print("升序排序结果:", sorted_x)
print("原索引位置:", indices)
# 降序排序
sorted_x_desc, indices_desc = torch.sort(x, descending=True)
print("降序排序结果:", sorted_x_desc)
print("原索引位置:", indices_desc)
运行结果:
升序排序结果: tensor([1, 2, 3, 4, 5])
原索引位置: tensor([1, 3, 0, 2, 4])
降序排序结果: tensor([5, 4, 3, 2, 1])
原索引位置: tensor([4, 2, 0, 3, 1])
结果分析:
- 升序排序后,
sorted_x
为[1, 2, 3, 4, 5]
,indices
表示原张量中元素的位置(如sorted_x[0] = 1
来自原张量的索引1
)。 - 降序排序通过
descending=True
实现,结果为[5, 4, 3, 2, 1]
,索引对应原元素位置。
2. 多维张量排序(指定维度)
多维张量排序需通过 dim
参数指定排序维度,不同维度的排序结果差异显著。
代码示例:
x = torch.tensor([[3, 1, 2],
[6, 4, 5]]) # 形状 (2, 3)
# 按列维度(dim=1)升序排序(每行内部排序)
sorted_row, indices_row = torch.sort(x, dim=1)
print("按行内元素排序结果:\n", sorted_row)
print("行内排序索引:\n", indices_row)
# 按行维度(dim=0)升序排序(每列内部排序)
sorted_col, indices_col = torch.sort(x, dim=0)
print("按列内元素排序结果:\n", sorted_col)
print("列内排序索引:\n", indices_col)
运行结果:
按行内元素排序结果:
tensor([[1, 2, 3],
[4, 5, 6]])
行内排序索引:
tensor([[1, 2, 0],
[1, 2, 0]])
按列内元素排序结果:
tensor([[3, 1, 2],
[6, 4, 5]])
列内排序索引:
tensor([[0, 0, 0],
[1, 1, 1]])
结果分析:
dim=1
表示按列维度排序(每行内部元素重新排列),第一行[3,1,2]
排序后为[1,2,3]
,索引[1,2,0]
对应原元素位置。dim=0
表示按行维度排序(每列内部元素重新排列),由于原张量列元素已按升序排列(如第一列[3,6]
),排序后结果不变,索引[0,1]
表示原行位置。
三、Top-K 选取:torch.topk()
在很多场景中,我们不需要对整个张量排序,只需获取最大或最小的 k
个元素(如推荐系统中的 Top-N 物品)。torch.topk()
函数可高效实现这一功能,无需全量排序,计算效率更高。
函数原型:
torch.topk(input, k, dim=-1, largest=True, sorted=True, *, out=None) → (Tensor, Tensor)
参数说明:
input
:输入张量。k
:需要选取的元素数量(必须满足1 ≤ k ≤ 输入张量在指定维度的大小
)。dim
(默认-1
):选取元素的维度。largest
(默认True
):选取方式,True
表示取最大的k
个元素,False
表示取最小的k
个元素。sorted
(默认True
):返回结果是否按大小排序(True
表示排序,False
表示不保证顺序)。out
(可选):元组(values, indices)
,用于存储输出结果。
返回值:
- 元组
(values, indices)
:values
:选取的k
个元素值。indices
:这些元素在原张量中的位置索引。
1. 一维张量的 Top-K 选取
代码示例:
x = torch.tensor([7, 2, 8, 1, 9, 3])
# 取最大的 3 个元素(默认)
top3_vals, top3_indices = torch.topk(x, k=3)
print("最大的 3 个元素:", top3_vals)
print("对应原索引:", top3_indices)
# 取最小的 2 个元素(largest=False)
bottom2_vals, bottom2_indices = torch.topk(x, k=2, largest=False)
print("最小的 2 个元素:", bottom2_vals)
print("对应原索引:", bottom2_indices)
运行结果:
最大的 3 个元素: tensor([9, 8, 7])
对应原索引: tensor([4, 2, 0])
最小的 2 个元素: tensor([1, 2])
对应原索引: tensor([3, 1])
结果分析:
k=3
且largest=True
时,返回最大的 3 个元素[9,8,7]
,其原索引分别为4
(x[4]=9
)、2
(x[2]=8
)、0
(x[0]=7
)。largest=False
时,返回最小的 2 个元素[1,2]
,对应原索引3
和1
。
2. 多维张量的 Top-K 选取(指定维度)
代码示例:
x = torch.tensor([[5, 2, 8],
[3, 9, 1]]) # 形状 (2, 3)
# 对每行取最大的 2 个元素(dim=1)
row_top2_vals, row_top2_indices = torch.topk(x, k=2, dim=1)
print("每行最大的 2 个元素:\n", row_top2_vals)
print("对应列索引:\n", row_top2_indices)
# 对每列取最小的 1 个元素(dim=0, largest=False)
col_bottom1_vals, col_bottom1_indices = torch.topk(x, k=1, dim=0, largest=False)
print("每列最小的 1 个元素:\n", col_bottom1_vals)
print("对应行索引:\n", col_bottom1_indices)
运行结果:
每行最大的 2 个元素:
tensor([[8, 5],
[9, 3]])
对应列索引:
tensor([[2, 0],
[1, 0]])
每列最小的 1 个元素:
tensor([[3],
[2],
[1]])
对应行索引:
tensor([[1],
[0],
[1]])
结果分析:
dim=1
表示按行取 Top-K,第一行[5,2,8]
最大的 2 个元素是8
(列索引 2)和5
(列索引 0)。dim=0
且largest=False
表示按列取最小元素,第一列[5,3]
最小元素是3
(行索引 1),以此类推。
四、K-th 值选取:torch.kthvalue()
torch.kthvalue()
用于获取张量中第 k 小的元素(按升序排列后的第 k 个元素,索引从 1 开始)。与 torch.topk()
不同,它聚焦于“特定排名”的元素,而非前 k 个元素。
函数原型:
torch.kthvalue(input, k, dim=-1, *, out=None) → (Tensor, Tensor)
参数说明:
input
:输入张量。k
:第 k 小的元素(1 ≤ k ≤ 输入张量在指定维度的大小
,注意 k 从 1 开始计数)。dim
(默认-1
):选取元素的维度。out
(可选):元组(value, index)
,用于存储输出结果。
返回值:
- 元组
(value, index)
:value
:第 k 小的元素值。index
:该元素在原张量中的位置索引。
1. 一维张量的 K-th 值
代码示例:
x = torch.tensor([3, 1, 4, 2, 5]) # 升序排列后为 [1, 2, 3, 4, 5]
# 取第 3 小的元素(k=3)
k3_val, k3_idx = torch.kthvalue(x, k=3)
print("第 3 小的元素值:", k3_val)
print("对应原索引:", k3_idx)
# 取第 1 小的元素(k=1,即最小值)
k1_val, k1_idx = torch.kthvalue(x, k=1)
print("第 1 小的元素值(最小值):", k1_val)
运行结果:
第 3 小的元素值: tensor(3)
对应原索引: tensor(0)
第 1 小的元素值(最小值): tensor(1)
结果分析:
- 原张量升序排列后为
[1,2,3,4,5]
,第 3 小的元素是3
,对应原张量的索引0
(x[0]=3
)。 k=1
时返回最小值1
,验证了函数的正确性。
2. 多维张量的 K-th 值(指定维度)
代码示例:
x = torch.tensor([[5, 2, 8],
[3, 9, 1]]) # 形状 (2, 3)
# 对每行取第 2 小的元素(dim=1, k=2)
row_k2_val, row_k2_idx = torch.kthvalue(x, k=2, dim=1)
print("每行第 2 小的元素值:", row_k2_val)
print("对应列索引:", row_k2_idx)
运行结果:
每行第 2 小的元素值: tensor([5, 3])
对应列索引: tensor([0, 0])
结果分析:
- 第一行
[5,2,8]
升序后为[2,5,8]
,第 2 小的元素是5
,对应原列索引0
。 - 第二行
[3,9,1]
升序后为[1,3,9]
,第 2 小的元素是3
,对应原列索引0
。
五、数据合法性校验:检测异常值
在深度学习中,张量中若存在 NaN
(非数)、Inf
(无穷大)等异常值,会导致模型训练发散或推理结果错误。因此,数据预处理和训练过程中需对异常值进行检测和处理。PyTorch 提供了专门的函数用于异常值检测。
1. 检测 NaN
:torch.isnan()
NaN
(Not a Number)通常由无效运算产生(如 0/0
、sqrt(-1)
等),torch.isnan()
可标记张量中所有 NaN
元素。
函数原型:
torch.isnan(input, *, out=None) → Tensor
参数说明:
input
:输入张量(通常为浮点型)。out
(可选):输出布尔张量,用于存储结果。
代码示例:
x = torch.tensor([1.0, float('nan'), 3.0, torch.nan]) # 包含 NaN 的张量
# 检测 NaN
is_nan = torch.isnan(x)
print("NaN 位置标记:", is_nan)
print("非 NaN 元素:", x[~is_nan]) # ~ 表示逻辑取反
运行结果:
NaN 位置标记: tensor([False, True, False, True])
非 NaN 元素: tensor([1., 3.])
结果分析:
float('nan')
(Python 原生)和torch.nan
(PyTorch 定义)均会被检测为NaN
。- 通过
x[~is_nan]
可筛选出所有非NaN
元素,实现数据清洗。
2. 检测 Inf
:torch.isinf()
Inf
(无穷大)由溢出运算产生(如 1/0
),torch.isinf()
可标记所有 Inf
元素(包括正无穷 +inf
和负无穷 -inf
)。
函数原型:
torch.isinf(input, *, out=None) → Tensor
代码示例:
x = torch.tensor([1.0, float('inf'), -float('inf'), 5.0]) # 包含 Inf 的张量
# 检测 Inf
is_inf = torch.isinf(x)
print("Inf 位置标记:", is_inf)
print("非 Inf 元素:", x[~is_inf])
运行结果:
Inf 位置标记: tensor([False, True, True, False])
非 Inf 元素: tensor([1., 5.])
结果分析:
- 正无穷
float('inf')
和负无穷-float('inf')
均被标记为True
。 - 筛选后仅保留正常元素
[1.0, 5.0]
。
3. 检测有限值:torch.isfinite()
有限值指的是既不是 NaN
也不是 Inf
(包括正、负无穷)的正常数值(如 1.0
、-3.5
等)。torch.isfinite()
函数返回一个布尔张量,其中 True
表示对应元素是有限值,False
表示元素是 NaN
或 Inf
。
函数原型:
torch.isfinite(input, *, out=None) → Tensor
参数说明:
input
:输入张量(通常为浮点型)。out
(可选):输出布尔张量,用于存储结果。
代码示例:
x = torch.tensor([
1.0, # 有限值
torch.nan, # NaN(非有限值)
float('inf'), # 正无穷(非有限值)
-float('inf'),# 负无穷(非有限值)
3.14 # 有限值
])
# 检测有限值
is_finite = torch.isfinite(x)
print("有限值位置标记:", is_finite)
print("所有有限值元素:", x[is_finite]) # 直接筛选有限值
运行结果:
有限值位置标记: tensor([ True, False, False, False, True])
所有有限值元素: tensor([1.0000, 3.1400])
结果分析:
torch.isfinite(x)
直接标记出所有有限值元素(1.0
和3.14
),返回True
;对NaN
、+inf
、-inf
均返回False
。与 “先检测异常值再取反”(
~(is_nan | is_inf)
)相比,torch.isfinite()
是更简洁的方式,直接筛选有限值元素。
torch.isfinite()
与其他检测函数的关系
torch.isfinite()
的结果等价于对 torch.isnan()
和 torch.isinf()
取反的逻辑与,即:
is_finite = ~torch.isnan(x) & ~torch.isinf(x)
但 torch.isfinite()
是专门优化的函数,计算效率更高,且代码更简洁。
4. 综合校验:同时检测 NaN
和 Inf
实际应用中,通常需要同时检测 NaN
和 Inf
,可通过逻辑运算符 |
(或)实现。
代码示例:
x = torch.tensor([2.0, torch.nan, float('inf'), -float('inf'), 3.0])
# 同时检测 NaN 和 Inf
is_abnormal = torch.isnan(x) | torch.isinf(x)
print("异常值位置标记:", is_abnormal)
print("正常元素:", x[~is_abnormal])
运行结果:
异常值位置标记: tensor([False, True, True, True, False])
正常元素: tensor([2., 3.])
结果分析:
torch.isnan(x) | torch.isinf(x)
标记所有NaN
或Inf
元素,返回布尔张量。- 筛选后仅保留正常元素
[2.0, 3.0]
,确保数据合法性。
六、总结与应用场景
操作类型 | 核心函数/运算符 | 关键功能 | 典型应用场景 |
---|---|---|---|
比较运算 | == /!= /torch.eq() 等 |
元素级关系判断(等于、大于等) | 条件筛选、掩码生成、结果验证 |
排序 | torch.sort() |
按指定维度升序/降序排序,返回索引 | 全量排序、元素顺序分析 |
Top-K 选取 | torch.topk() |
高效获取最大/最小的 k 个元素 | 推荐系统、模型推理加速 |
K-th 值选取 | torch.kthvalue() |
获取第 k 小的元素 | 统计分析(如中位数计算) |
数据校验 | torch.isnan() /torch.isinf() |
检测 NaN /Inf 异常值 |
数据预处理、训练过程异常监控 |
掌握这些操作后,你可以:
- 快速筛选符合条件的张量元素;
- 高效获取排序后的结果或Top-K元素,优化模型推理;
- 检测并处理异常值,保障模型训练稳定性。
保数据合法性。
六、总结与应用场景
操作类型 | 核心函数/运算符 | 关键功能 | 典型应用场景 |
---|---|---|---|
比较运算 | == /!= /torch.eq() 等 |
元素级关系判断(等于、大于等) | 条件筛选、掩码生成、结果验证 |
排序 | torch.sort() |
按指定维度升序/降序排序,返回索引 | 全量排序、元素顺序分析 |
Top-K 选取 | torch.topk() |
高效获取最大/最小的 k 个元素 | 推荐系统、模型推理加速 |
K-th 值选取 | torch.kthvalue() |
获取第 k 小的元素 | 统计分析(如中位数计算) |
数据校验 | torch.isnan() /torch.isinf() |
检测 NaN /Inf 异常值 |
数据预处理、训练过程异常监控 |
掌握这些操作后,你可以:
- 快速筛选符合条件的张量元素;
- 高效获取排序后的结果或Top-K元素,优化模型推理;
- 检测并处理异常值,保障模型训练稳定性。
这些操作是深度学习开发的基础工具,无论是数据预处理、模型训练还是结果分析,都离不开它们的灵活应用。建议结合实际场景多做练习,加深理解。