PyTorch 张量核心操作——比较、排序与数据校验

发布于:2025-08-03 ⋅ 阅读:(16) ⋅ 点赞:(0)

PyTorch 张量核心操作——比较、排序与数据校验

在深度学习开发中,张量(Tensor)的比较、排序和数据校验是基础且高频的操作。无论是模型训练中的数据预处理,还是推理阶段的结果分析,这些操作都扮演着重要角色。本文将系统讲解 PyTorch 中与张量相关的比较运算、排序方法、Top-K 选取、K-th 值提取以及数据合法性校验,涵盖函数原型、参数详解、代码示例和结果分析,帮助初学者全面掌握这些核心技能。

一、张量比较运算:判断元素间的关系

比较运算用于判断张量元素之间的大小关系或相等性,返回与输入张量形状相同的布尔型张量(True/False)。PyTorch 提供了丰富的比较运算符和函数,支持元素级别的逐元素比较。

1. 基础比较运算符

PyTorch 支持与 Python 类似的比较运算符,包括 ==(等于)、!=(不等于)、>(大于)、<(小于)、>=(大于等于)、<=(小于等于)。这些运算符均为元素级运算,即对两个张量的对应元素逐一进行比较。

运算符特点:

  • 要求两个张量形状相同或可广播(广播机制见后文补充)。
  • 返回布尔型张量,True 表示满足条件,False 表示不满足。

运算原理:

张量的比较运算遵循 “位置对应” 原则 :两个张量必须形状相同(或可通过广播机制扩展为相同形状),然后对相同位置的元素逐一进行比较,最终生成一个形状相同的布尔张量(True/False)。

  • “大小” 的含义:张量中元素的大小就是其数值大小(如 3 > 2-1 < 0 等),与元素在张量中的位置无关。
  • 布尔张量的意义:结果中 True 表示对应位置的元素满足比较条件,False 表示不满足。

运算演示:

示例:比较两个一维张量 a = [1, 3, 5, 7]b = [2, 3, 4, 8]

  • 步骤 1:确认张量形状

    • a 的形状:(4,)(1 维,4 个元素)

    • b 的形状:(4,)(1 维,4 个元素)
      形状相同,可直接比较(无需广播)。

  • 步骤 2:逐元素比较(以 a > b 为例)

    比较逻辑:对每个索引 i(0 ≤ i < 4),判断 a[i] > b[i] 是否成立。

    • i=0a[0] = 1b[0] = 21 > 2?→ False

    • i=1a[1] = 3b[1] = 33 > 3?→ False

    • i=2a[2] = 5b[2] = 45 > 4?→ True

    • i=3a[3] = 7b[3] = 87 > 8?→ False

  • 步骤 3:生成结果张量

    • 将上述判断结果按原位置组合,得到布尔张量:a > b 的结果为 [False, False, True, False]

代码示例:

import torch

# 定义两个形状相同的张量
a = torch.tensor([1, 3, 5, 7])
b = torch.tensor([2, 3, 4, 8])

# 比较运算
print("a == b:", a == b)  # 等于
print("a != b:", a != b)  # 不等于
print("a > b: ", a > b)   # 大于
print("a < b: ", a < b)   # 小于
print("a >= b:", a >= b)  # 大于等于
print("a <= b:", a <= b)  # 小于等于

运行结果:

a == b: tensor([False,  True, False, False])
a != b: tensor([ True, False,  True,  True])
a > b:  tensor([False, False,  True, False])
a < b:  tensor([ True, False, False,  True])
a >= b: tensor([False,  True,  True, False])
a <= b: tensor([ True,  True, False,  True])

结果分析:

  • 每个运算符都对 ab 的对应元素进行比较(如 a[0]=1b[0]=2 比较,1 < 2a < b 的第 0 位为 True)。
  • 布尔张量的形状与输入张量一致(均为 (4,)),便于后续基于条件筛选元素(如 a[a > b] 可提取 a 中大于 b 对应元素的值)。

2. 比较函数:torch.eq()torch.ne()

除运算符外,PyTorch 还提供了对应的函数形式,如 torch.eq()(等于)、torch.ne()(不等于)、torch.gt()(大于)、torch.lt()(小于)、torch.ge()(大于等于)、torch.le()(小于等于)。这些函数与运算符功能一致,但支持更灵活的参数设置(如广播)。

函数原型:

torch.eq(input, other, *, out=None) → Tensor
torch.ne(input, other, *, out=None) → Tensor
torch.gt(input, other, *, out=None) → Tensor
# 其余函数参数类似

参数说明:

  • input:输入张量(第一个比较对象)。
  • other:第二个比较对象(可以是张量或标量)。
  • out(可选):输出张量,用于存储结果(需与预期输出形状一致)。

代码示例(支持广播机制):

# 形状不同但可广播的张量比较
a = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 形状 (2, 3)
b = torch.tensor([3])  # 标量张量(形状 ()),可广播为 (2, 3)

# 使用函数进行比较(等价于 a < 3)
lt_result = torch.lt(a, b)  
print("a < 3 的结果:\n", lt_result)

运行结果:

a < 3 的结果:
 tensor([[ True,  True, False],
        [False, False, False]])

结果分析:

  • b 是标量张量,通过广播机制自动扩展为与 a 同形状的张量 [[3, 3, 3], [3, 3, 3]]
  • torch.lt(a, b) 逐元素比较 a 和扩展后的 b,返回布尔张量(a 中元素小于 3 的位置为 True)。

比较运算的核心应用:

  • 条件筛选:通过 tensor[布尔张量] 提取满足条件的元素(如 a[a > 5] 提取 a 中大于 5 的元素)。
  • 掩码操作:生成掩码张量用于数据过滤或加权计算。
  • 结果验证:在模型推理中判断预测结果与标签的匹配情况(如计算准确率时统计 pred == label 的数量)。

3. 张量相等性判断 torch.equal()

在 PyTorch 中,torch.equal() 是一个用于判断两个张量是否完全相等的函数。它与我们之前讲过的元素级比较运算符(如 ==)不同,后者返回一个布尔张量,而 torch.equal() 会返回一个单一的布尔值(TrueFalse),表示两个张量是否在所有元素和形状上都完全一致。

3.1 torch.equal() 函数原型
torch.equal(input1, input2)bool

参数说明:

  • input1:第一个待比较的张量。
  • input2:第二个待比较的张量。

返回值:

  • 布尔值(TrueFalse):如果两个张量的形状相同所有对应元素都相等,则返回 True;否则返回 False
3.2 核心特点:判断“整体相等性”

torch.equal() 的核心是整体判断,而非元素级判断。它有两个严格条件:

  1. 两个张量的形状必须完全相同
  2. 两个张量所有对应位置的元素必须完全相等

只有同时满足这两个条件,才会返回 True

3.3 代码示例与结果分析

示例 1:形状相同且元素全相等

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])

print(torch.equal(a, b))  # 输出:True

分析ab 形状均为 (3,),且所有元素对应相等,因此返回 True

示例 2:形状相同但元素不全相等

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 4, 3])

print(torch.equal(a, b))  # 输出:False

分析:虽然形状相同,但 a[1]=2b[1]=4 不相等,因此返回 False

示例 3:形状不同(即使元素“看起来”对应)

a = torch.tensor([[1, 2], [3, 4]])  # 形状 (2, 2)
b = torch.tensor([1, 2, 3, 4])      # 形状 (4,)

print(torch.equal(a, b))  # 输出:False

分析a 是 2×2 的二维张量,b 是长度为 4 的一维张量,形状不同,直接返回 False

示例 4:浮点数的相等判断(需注意精度)

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0 + 1e-9, 2.0])  # 第一个元素有微小差异

print(torch.equal(a, b))  # 输出:False

分析:浮点数由于精度问题,即使差异极小(如 1e-9),torch.equal() 也会判定为不相等。如果需要忽略微小误差,应使用 torch.allclose()(后续会介绍)。

示例 5:包含 NaN 的张量(特殊情况)

a = torch.tensor([1.0, torch.nan])
b = torch.tensor([1.0, torch.nan])

print(torch.equal(a, b))  # 输出:False

分析:由于 NaN 与任何值(包括自身)都不相等(数学定义),因此即使两个张量都含 NaNtorch.equal() 也会返回 False

3.4 与 == 运算符的区别
操作 返回值类型 核心逻辑 典型用途
torch.equal(a, b) 单一布尔值(bool 判断整体是否完全相等 验证两个张量是否完全一致
a == b 布尔张量(Tensor 元素级比较,返回每个位置的结果 筛选特定位置的元素

对比示例:

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 4, 3])

print("a == b 的结果:", a == b)         # 元素级比较
print("torch.equal(a, b):", torch.equal(a, b))  # 整体比较

输出

a == b 的结果: tensor([ True, False,  True])
torch.equal(a, b): False
3.5 使用场景与注意事项

场景:

  1. 验证张量是否完全一致:如检查模型参数在训练前后是否发生预期变化,或验证两个计算结果是否完全相同。
  2. 单元测试:在测试代码中,判断函数输出是否与预期张量完全一致。
  3. 调试:排查代码中张量是否在传递过程中被意外修改(形状或元素)。

注意:

  1. 形状优先:即使两个张量的元素“数量相同”但形状不同(如 (2,2)(4,)),也会返回 False
  2. 浮点数精度:对浮点数张量,微小的精度误差(如数值计算中的舍入误差)会导致返回 False,此时应使用 torch.allclose() 并设置合理的误差容限(如 atol=1e-5)。
  3. NaN 特殊处理:含 NaN 的张量几乎不可能被判定为相等,需单独处理 NaN 位置(如先通过 torch.isnan() 检测)。

总之,torch.equal() 是判断两个张量“完全一致性”的便捷工具,适合需要严格验证张量是否相同的场景。但使用时需注意形状、浮点数精度和 NaN 等特殊情况。

二、排序操作:torch.sort()

排序是将张量元素按升序或降序重新排列的操作。PyTorch 中 torch.sort() 函数不仅返回排序后的张量,还返回原元素在排序后的位置索引,这对后续分析元素来源至关重要。

函数原型:

torch.sort(input, dim=-1, descending=False, *, out=None)(Tensor, Tensor)

参数说明:

  • input:需要排序的输入张量。
  • dim(默认 -1):指定排序的维度(如 dim=0 按行排序,dim=1 按列排序)。
  • descending(默认 False):排序方式,False 为升序,True 为降序。
  • out(可选):元组 (sorted_tensor, indices),用于存储输出结果。

返回值:

  • 元组 (sorted_tensor, indices)
    • sorted_tensor:排序后的张量,形状与输入一致。
    • indices:整数张量,记录原张量元素在排序后张量中的位置索引。

1. 一维张量排序

代码示例:

x = torch.tensor([3, 1, 4, 2, 5])

# 升序排序(默认)
sorted_x, indices = torch.sort(x)  
print("升序排序结果:", sorted_x)
print("原索引位置:", indices)

# 降序排序
sorted_x_desc, indices_desc = torch.sort(x, descending=True)  
print("降序排序结果:", sorted_x_desc)
print("原索引位置:", indices_desc)

运行结果:

升序排序结果: tensor([1, 2, 3, 4, 5])
原索引位置: tensor([1, 3, 0, 2, 4])
降序排序结果: tensor([5, 4, 3, 2, 1])
原索引位置: tensor([4, 2, 0, 3, 1])

结果分析:

  • 升序排序后,sorted_x[1, 2, 3, 4, 5]indices 表示原张量中元素的位置(如 sorted_x[0] = 1 来自原张量的索引 1)。
  • 降序排序通过 descending=True 实现,结果为 [5, 4, 3, 2, 1],索引对应原元素位置。

2. 多维张量排序(指定维度)

多维张量排序需通过 dim 参数指定排序维度,不同维度的排序结果差异显著。

代码示例:

x = torch.tensor([[3, 1, 2], 
                  [6, 4, 5]])  # 形状 (2, 3)

# 按列维度(dim=1)升序排序(每行内部排序)
sorted_row, indices_row = torch.sort(x, dim=1)  
print("按行内元素排序结果:\n", sorted_row)
print("行内排序索引:\n", indices_row)

# 按行维度(dim=0)升序排序(每列内部排序)
sorted_col, indices_col = torch.sort(x, dim=0)  
print("按列内元素排序结果:\n", sorted_col)
print("列内排序索引:\n", indices_col)

运行结果:

按行内元素排序结果:
 tensor([[1, 2, 3],
        [4, 5, 6]])
行内排序索引:
 tensor([[1, 2, 0],
        [1, 2, 0]])
按列内元素排序结果:
 tensor([[3, 1, 2],
        [6, 4, 5]])
列内排序索引:
 tensor([[0, 0, 0],
        [1, 1, 1]])

结果分析:

  • dim=1 表示按列维度排序(每行内部元素重新排列),第一行 [3,1,2] 排序后为 [1,2,3],索引 [1,2,0] 对应原元素位置。
  • dim=0 表示按行维度排序(每列内部元素重新排列),由于原张量列元素已按升序排列(如第一列 [3,6]),排序后结果不变,索引 [0,1] 表示原行位置。

三、Top-K 选取:torch.topk()

在很多场景中,我们不需要对整个张量排序,只需获取最大或最小的 k 个元素(如推荐系统中的 Top-N 物品)。torch.topk() 函数可高效实现这一功能,无需全量排序,计算效率更高。

函数原型:

torch.topk(input, k, dim=-1, largest=True, sorted=True, *, out=None)(Tensor, Tensor)

参数说明:

  • input:输入张量。
  • k:需要选取的元素数量(必须满足 1 ≤ k ≤ 输入张量在指定维度的大小)。
  • dim(默认 -1):选取元素的维度。
  • largest(默认 True):选取方式,True 表示取最大的 k 个元素,False 表示取最小的 k 个元素。
  • sorted(默认 True):返回结果是否按大小排序(True 表示排序,False 表示不保证顺序)。
  • out(可选):元组 (values, indices),用于存储输出结果。

返回值:

  • 元组 (values, indices)
    • values:选取的 k 个元素值。
    • indices:这些元素在原张量中的位置索引。

1. 一维张量的 Top-K 选取

代码示例:

x = torch.tensor([7, 2, 8, 1, 9, 3])

# 取最大的 3 个元素(默认)
top3_vals, top3_indices = torch.topk(x, k=3)  
print("最大的 3 个元素:", top3_vals)
print("对应原索引:", top3_indices)

# 取最小的 2 个元素(largest=False)
bottom2_vals, bottom2_indices = torch.topk(x, k=2, largest=False)  
print("最小的 2 个元素:", bottom2_vals)
print("对应原索引:", bottom2_indices)

运行结果:

最大的 3 个元素: tensor([9, 8, 7])
对应原索引: tensor([4, 2, 0])
最小的 2 个元素: tensor([1, 2])
对应原索引: tensor([3, 1])

结果分析:

  • k=3largest=True 时,返回最大的 3 个元素 [9,8,7],其原索引分别为 4x[4]=9)、2x[2]=8)、0x[0]=7)。
  • largest=False 时,返回最小的 2 个元素 [1,2],对应原索引 31

2. 多维张量的 Top-K 选取(指定维度)

代码示例:

x = torch.tensor([[5, 2, 8], 
                  [3, 9, 1]])  # 形状 (2, 3)

# 对每行取最大的 2 个元素(dim=1)
row_top2_vals, row_top2_indices = torch.topk(x, k=2, dim=1)  
print("每行最大的 2 个元素:\n", row_top2_vals)
print("对应列索引:\n", row_top2_indices)

# 对每列取最小的 1 个元素(dim=0, largest=False)
col_bottom1_vals, col_bottom1_indices = torch.topk(x, k=1, dim=0, largest=False)  
print("每列最小的 1 个元素:\n", col_bottom1_vals)
print("对应行索引:\n", col_bottom1_indices)

运行结果:

每行最大的 2 个元素:
 tensor([[8, 5],
        [9, 3]])
对应列索引:
 tensor([[2, 0],
        [1, 0]])
每列最小的 1 个元素:
 tensor([[3],
        [2],
        [1]])
对应行索引:
 tensor([[1],
        [0],
        [1]])

结果分析:

  • dim=1 表示按行取 Top-K,第一行 [5,2,8] 最大的 2 个元素是 8(列索引 2)和 5(列索引 0)。
  • dim=0largest=False 表示按列取最小元素,第一列 [5,3] 最小元素是 3(行索引 1),以此类推。

四、K-th 值选取:torch.kthvalue()

torch.kthvalue() 用于获取张量中第 k 小的元素(按升序排列后的第 k 个元素,索引从 1 开始)。与 torch.topk() 不同,它聚焦于“特定排名”的元素,而非前 k 个元素。

函数原型:

torch.kthvalue(input, k, dim=-1, *, out=None)(Tensor, Tensor)

参数说明:

  • input:输入张量。
  • k:第 k 小的元素(1 ≤ k ≤ 输入张量在指定维度的大小,注意 k 从 1 开始计数)。
  • dim(默认 -1):选取元素的维度。
  • out(可选):元组 (value, index),用于存储输出结果。

返回值:

  • 元组 (value, index)
    • value:第 k 小的元素值。
    • index:该元素在原张量中的位置索引。

1. 一维张量的 K-th 值

代码示例:

x = torch.tensor([3, 1, 4, 2, 5])  # 升序排列后为 [1, 2, 3, 4, 5]

# 取第 3 小的元素(k=3)
k3_val, k3_idx = torch.kthvalue(x, k=3)  
print("第 3 小的元素值:", k3_val)
print("对应原索引:", k3_idx)

# 取第 1 小的元素(k=1,即最小值)
k1_val, k1_idx = torch.kthvalue(x, k=1)  
print("第 1 小的元素值(最小值):", k1_val)

运行结果:

第 3 小的元素值: tensor(3)
对应原索引: tensor(0)
第 1 小的元素值(最小值): tensor(1)

结果分析:

  • 原张量升序排列后为 [1,2,3,4,5],第 3 小的元素是 3,对应原张量的索引 0x[0]=3)。
  • k=1 时返回最小值 1,验证了函数的正确性。

2. 多维张量的 K-th 值(指定维度)

代码示例:

x = torch.tensor([[5, 2, 8], 
                  [3, 9, 1]])  # 形状 (2, 3)

# 对每行取第 2 小的元素(dim=1, k=2)
row_k2_val, row_k2_idx = torch.kthvalue(x, k=2, dim=1)  
print("每行第 2 小的元素值:", row_k2_val)
print("对应列索引:", row_k2_idx)

运行结果:

每行第 2 小的元素值: tensor([5, 3])
对应列索引: tensor([0, 0])

结果分析:

  • 第一行 [5,2,8] 升序后为 [2,5,8],第 2 小的元素是 5,对应原列索引 0
  • 第二行 [3,9,1] 升序后为 [1,3,9],第 2 小的元素是 3,对应原列索引 0

五、数据合法性校验:检测异常值

在深度学习中,张量中若存在 NaN(非数)、Inf(无穷大)等异常值,会导致模型训练发散或推理结果错误。因此,数据预处理和训练过程中需对异常值进行检测和处理。PyTorch 提供了专门的函数用于异常值检测。

1. 检测 NaNtorch.isnan()

NaN(Not a Number)通常由无效运算产生(如 0/0sqrt(-1) 等),torch.isnan() 可标记张量中所有 NaN 元素。

函数原型:

torch.isnan(input, *, out=None) → Tensor

参数说明:

  • input:输入张量(通常为浮点型)。
  • out(可选):输出布尔张量,用于存储结果。

代码示例:

x = torch.tensor([1.0, float('nan'), 3.0, torch.nan])  # 包含 NaN 的张量

# 检测 NaN
is_nan = torch.isnan(x)  
print("NaN 位置标记:", is_nan)
print("非 NaN 元素:", x[~is_nan])  # ~ 表示逻辑取反

运行结果:

NaN 位置标记: tensor([False,  True, False,  True])
非 NaN 元素: tensor([1., 3.])

结果分析:

  • float('nan')(Python 原生)和 torch.nan(PyTorch 定义)均会被检测为 NaN
  • 通过 x[~is_nan] 可筛选出所有非 NaN 元素,实现数据清洗。

2. 检测 Inftorch.isinf()

Inf(无穷大)由溢出运算产生(如 1/0),torch.isinf() 可标记所有 Inf 元素(包括正无穷 +inf 和负无穷 -inf)。

函数原型:

torch.isinf(input, *, out=None) → Tensor

代码示例:

x = torch.tensor([1.0, float('inf'), -float('inf'), 5.0])  # 包含 Inf 的张量

# 检测 Inf
is_inf = torch.isinf(x)  
print("Inf 位置标记:", is_inf)
print("非 Inf 元素:", x[~is_inf])

运行结果:

Inf 位置标记: tensor([False,  True,  True, False])
非 Inf 元素: tensor([1., 5.])

结果分析:

  • 正无穷 float('inf') 和负无穷 -float('inf') 均被标记为 True
  • 筛选后仅保留正常元素 [1.0, 5.0]

3. 检测有限值:torch.isfinite()

有限值指的是既不是 NaN 也不是 Inf(包括正、负无穷)的正常数值(如 1.0-3.5 等)。torch.isfinite() 函数返回一个布尔张量,其中 True 表示对应元素是有限值,False 表示元素是 NaNInf

函数原型:

torch.isfinite(input, *, out=None) → Tensor

参数说明:

  • input:输入张量(通常为浮点型)。
  • out(可选):输出布尔张量,用于存储结果。

代码示例:

x = torch.tensor([
    1.0,          # 有限值
    torch.nan,    # NaN(非有限值)
    float('inf'), # 正无穷(非有限值)
    -float('inf'),# 负无穷(非有限值)
    3.14          # 有限值
])

# 检测有限值
is_finite = torch.isfinite(x)  
print("有限值位置标记:", is_finite)
print("所有有限值元素:", x[is_finite])  # 直接筛选有限值

运行结果:

有限值位置标记: tensor([ True, False, False, False,  True])
所有有限值元素: tensor([1.0000, 3.1400])

结果分析:

  • torch.isfinite(x) 直接标记出所有有限值元素(1.03.14),返回 True;对 NaN+inf-inf 均返回 False

  • 与 “先检测异常值再取反”(~(is_nan | is_inf))相比,torch.isfinite() 是更简洁的方式,直接筛选有限值元素。

torch.isfinite() 与其他检测函数的关系

torch.isfinite() 的结果等价于对 torch.isnan()torch.isinf() 取反的逻辑与,即:

is_finite = ~torch.isnan(x) & ~torch.isinf(x)

torch.isfinite() 是专门优化的函数,计算效率更高,且代码更简洁。

4. 综合校验:同时检测 NaNInf

实际应用中,通常需要同时检测 NaNInf,可通过逻辑运算符 |(或)实现。

代码示例:

x = torch.tensor([2.0, torch.nan, float('inf'), -float('inf'), 3.0])

# 同时检测 NaN 和 Inf
is_abnormal = torch.isnan(x) | torch.isinf(x)  
print("异常值位置标记:", is_abnormal)
print("正常元素:", x[~is_abnormal])

运行结果:

异常值位置标记: tensor([False,  True,  True,  True, False])
正常元素: tensor([2., 3.])

结果分析:

  • torch.isnan(x) | torch.isinf(x) 标记所有 NaNInf 元素,返回布尔张量。
  • 筛选后仅保留正常元素 [2.0, 3.0],确保数据合法性。

六、总结与应用场景

操作类型 核心函数/运算符 关键功能 典型应用场景
比较运算 ==/!=/torch.eq() 元素级关系判断(等于、大于等) 条件筛选、掩码生成、结果验证
排序 torch.sort() 按指定维度升序/降序排序,返回索引 全量排序、元素顺序分析
Top-K 选取 torch.topk() 高效获取最大/最小的 k 个元素 推荐系统、模型推理加速
K-th 值选取 torch.kthvalue() 获取第 k 小的元素 统计分析(如中位数计算)
数据校验 torch.isnan()/torch.isinf() 检测 NaN/Inf 异常值 数据预处理、训练过程异常监控

掌握这些操作后,你可以:

  • 快速筛选符合条件的张量元素;
  • 高效获取排序后的结果或Top-K元素,优化模型推理;
  • 检测并处理异常值,保障模型训练稳定性。

保数据合法性。

六、总结与应用场景

操作类型 核心函数/运算符 关键功能 典型应用场景
比较运算 ==/!=/torch.eq() 元素级关系判断(等于、大于等) 条件筛选、掩码生成、结果验证
排序 torch.sort() 按指定维度升序/降序排序,返回索引 全量排序、元素顺序分析
Top-K 选取 torch.topk() 高效获取最大/最小的 k 个元素 推荐系统、模型推理加速
K-th 值选取 torch.kthvalue() 获取第 k 小的元素 统计分析(如中位数计算)
数据校验 torch.isnan()/torch.isinf() 检测 NaN/Inf 异常值 数据预处理、训练过程异常监控

掌握这些操作后,你可以:

  • 快速筛选符合条件的张量元素;
  • 高效获取排序后的结果或Top-K元素,优化模型推理;
  • 检测并处理异常值,保障模型训练稳定性。

这些操作是深度学习开发的基础工具,无论是数据预处理、模型训练还是结果分析,都离不开它们的灵活应用。建议结合实际场景多做练习,加深理解。