Pytorch:张量的索引操作

发布于:2024-04-19 ⋅ 阅读:(16) ⋅ 点赞:(0)

参考神君

一、索引

1.使用整数索引访问单个元素
import torch
x=torch.tensor([[1,2,3],[4,5,6]])
lst = [[1,2,3],[4,5,6]]
print(x[0][1])#等价于x[0,1]
print(lst[0][1])

输出:

tensor(2)
2
2.使用多个整数索引访问多个元素*

list没有这种索引方式。这种索引方式,通过在索引中,给定多个列表或多个张量,索引按这个列表或张量的次序访问元素。

a.示例详解

考虑以下的 3x3 张量:

1 2 3
4 5 6
7 8 9

假设我们想从中选择元素 26。这两个元素分别位于:

  • 2 在第0行第1列
  • 6 在第1行第2列

要使用 fancy indexing 来选择这些元素,可以做如下操作:

import torch

# 创建一个 3x3 的张量
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# 定义行索引和列索引
rows = torch.tensor([0, 1])
columns = torch.tensor([1, 2])

# 使用行索引和列索引进行选择
selected_elements = tensor[rows, columns]

print(selected_elements)  # 输出 tensor([2, 6])

在上面的例子中,rowscolumns 是两个张量,分别指定了要访问的行和列的索引。当你传递 rowscolumns 到原始张量 tensor 时,PyTorch 会解释这样的索引方式:

  • 对于每对 (row, column) 索引,选择对应的元素。
  • rows[0]columns[0] 组合指定了元素 tensor[0, 1](即 2)。
  • rows[1]columns[1] 组合指定了元素 tensor[1, 2](即 6)。

这种索引非常灵活,允许你从张量中快速选择一个不规则的元素集合。这对于机器学习和数据预处理任务尤其有用,因为经常需要从数据集中提取特定的样本或特征。

b.示例
import torch
x=torch.tensor([[1,2,3],[4,5,6]])
print(x[[0,1],[1,1]])

输出:

tensor([2, 5])#x[0][1]和x[1][1]

当我们只给定一个维度的时候,也是一样的。

y=torch.tensor([1,2,3,4,5,6])
print(y[[0,1]])

输出:

tensor([1, 2]) # y[0]和y[1]
3.使用负数索引从张量的末尾开始计数

如果某个维度的索引是-1,则在该维取末尾元素。

import torch
x=torch.tensor([[1,2,3],[4,5,6]])
print(x[[0,1],[1,-1]])

输出:

tensor([2, 6])
4.使用布尔索引访问满足条件的元素*

list没有这种索引方式。

  高级索引是一种在 PyTorch 和 NumPy 中常用的索引方法,它允许你从数组或张量中选择复杂的、非连续的数据子集。高级索引可以通过传递整数数组、张量或列表来实现,而这些索引方式相比基本的切片提供了更大的灵活性。在 PyTorch 中,使用高级索引时,索引操作的结果通常会形成一个新的张量,不与原始数据共享内存。

高级索引的几种常见形式:

  1. 整数数组索引
    使用整数数组进行索引时,你可以指定要访问的每个维度上的索引位置。这种方式可以从张量中选择任意位置的数据,而这些数据可以是非连续和非规则的。

  2. 布尔(掩码)索引
    布尔索引允许你使用布尔数组(通常是逻辑条件的结果)来选择张量的元素。这种方法非常适用于基于条件的筛选。

a.张量的元素级布尔操作

在 PyTorch(以及其他类似的库,如 NumPy)中,使用张量进行布尔表达式的操作本质上是一种称为“元素级”或“元素对元素”的操作。当你在布尔表达式中使用张量时,PyTorch 会自动应用广播和矢量化操作,使得表达式能够逐元素地计算结果。这种处理方式使得代码不仅可读性好,而且效率高,非常适合科学计算和机器学习任务。

元素级布尔操作:

在 PyTorch 中,布尔操作(例如比较操作 <, >, <=, >=, ==, !=)都是元素级的。这意味着每个操作都是在输入张量的对应元素间独立进行的,并生成一个布尔类型的张量,其中的每个元素都是单个比较的结果。

原理解释:

  1. 矢量化
    矢量化是指使用优化的库例程一次处理整个数组(或张量),而不是在Python层面上使用循环处理数组的每个元素。这减少了循环的开销,提高了执行速度,尤其是在底层使用如C/C++等编译语言实现的情况下。

  2. 广播
    广播是一种灵活处理不同形状张量的方法。当在两个不同大小的张量上进行操作时,较小的张量会自动“扩展”其维度以匹配较大张量的形状。例如,如果你有一个形状为(3,1)的张量和一个形状为(1,4)的张量,那么在操作中,每个张量都会广播到(3,4)以进行元素级操作。

示例:

考虑两个张量 AB,大小分别为 (3,3)(3,)

import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
B = torch.tensor([2, 2, 2])

# 元素级比较
C = A > B
print(C)

输出:

tensor([[False, False,  True],
        [ True,  True,  True],
        [ True,  True,  True]])

在这里,B 被广播到与 A 相同的形状 (3,3),每个元素的比较都是独立进行的。

b.布尔索引示例

布尔索引可以用来选择满足特定条件的元素。例如,选择张量中所有大于5的元素:

# 创建一个 3x4 的张量
tensor = torch.tensor([[1, 2, 3, 4],
                       [5, 6, 7, 8],
                       [9, 10, 11, 12]])

# 布尔索引
mask = tensor > 5
selected_elements = tensor[mask]
print(mask)
print(selected_elements)

输出:

tensor([[False, False, False, False],
        [False,  True,  True,  True],
        [ True,  True,  True,  True]])
tensor([ 6,  7,  8,  9, 10, 11, 12])
5.torch.where()函数根据条件选择元素

torch.where() 是一个非常有用的函数,在 PyTorch 中用于根据条件从两个张量中选择元素。该函数的工作原理相似于 NumPy 中的 np.where(),允许在条件为真时从一个张量选择元素,条件为假时从另一个张量选择元素。

a.函数原型

torch.where() 函数的基本语法如下:

torch.where(condition, x, y)
  • condition: 一个布尔张量,其中的每个元素都对应于 xy 张量中相应位置的条件检查。
  • x: 当条件为真时将从这个张量中选择元素。
  • y:当条件为假时将从这个张量中选择元素。

结果张量的每个位置会根据 condition 张量在相应位置的值是真还是假,从 xy 张量中选择值。


torch.where(condition)
  • 当仅提供 condition参数时,此函数返回满足条件的元素的索引。这可以用于找出满足特定条件的所有元素的位置。返回的是一个元组,其中每个元素是一个张量,分别代表满足条件的元素在各个维度上的索引。
# 创建温度张量
temperatures = torch.tensor([-5, 13, -2, 8, -1])
condition = temperatures < 0
indexs = torch.where(condition)
corrected_temperatures = temperatures[indexs]
print(corrected_temperatures)

输出:

tensor([-5, -2, -1])
b.示例

假设有一个张量代表温度值,想要将所有低于零度的值设置为零:

# 创建温度张量
temperatures = torch.tensor([-5, 13, -2, 8, -1])

# 使用 torch.where() 来修正负值
# 注意,这里的temperatures<0是之前提到的元素级布尔操作,生成的是tensor([True,False,True,False,True])
corrected_temperatures = torch.where(temperatures < 0, torch.tensor(0), temperatures)

print(corrected_temperatures)

or:

# 创建温度张量
temperatures = torch.tensor([-5, 13, -2, 8, -1])
condition = temperatures < 0
indexs = torch.where(condition)
for i in indexs:
	temperatures[i]=0
print(temperatures)

输出:

tensor([ 0, 13,  0,  8,  0])

import torch
A = torch.tensor([5,1,2])
B = torch.tensor([2,4,7])
C = torch.where(A<B,A,B)
print(C)

输出:

tensor([2, 1, 2])
6. torch.take()函数按索引从张量中选择元素

在 PyTorch 中,.take() 函数是一个实用的张量操作,它用于从输入张量中按照指定的索引来提取元素。这个函数允许你将输入张量视为一维张量,并使用一维索引从中选择元素。这种方式特别适用于从多维张量中按特定顺序选择元素,而不必担心张量的原始维度。

a. .take() 函数的基本用法

函数如下:

tensor.take(indices)  or torch.take(tensor,indices)
  • indices:一个包含要提取的元素索引的一维张量。

这个函数按照 indices 提供的索引从输入张量中取出元素。索引假设输入张量是一维的,并按照行优先(C样式)顺序展开。

b.示例:

假设你有一个二维张量,想根据特定的索引列表从中选择元素。下面是如何使用 .take() 来实现这一点的示例:

import torch

# 创建一个二维张量
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# 定义一维索引张量
indices = torch.tensor([0, 4, 8])

# 使用.take()选择元素
selected_elements = tensor.take(indices)

print(selected_elements)

在这个例子中,selected_elements 将包含由 indices 指定的位置的元素。输出将是:

tensor([1, 5, 9])

这里的索引 0, 4, 8 分别对应张量展开后的第1,第5,第9个元素(考虑到从零开始索引)。