Pytorch中张量的索引和切片使用详解和代码示例

发布于:2025-07-16 ⋅ 阅读:(13) ⋅ 点赞:(0)

PyTorch 中张量索引与切片详解

使用前先导入:

import torch

1.基础索引(类似 Python / NumPy)

适用于低维张量:x[i]x[i, j]

x = torch.tensor([[10, 11, 12],
                  [13, 14, 15],
                  [16, 17, 18]])

print(x[0])         # 第0行: tensor([10, 11, 12])
print(x[1][2])      # 第1行第2列: 15
print(x[2, 1])      # 第2行第1列: 17

2.切片(Slicing)

x = torch.arange(16).reshape(4, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])

print(x[:2])        # 前两行
print(x[:, 1:3])    # 所有行,第1~2列
print(x[::2, ::2])  # 行列间隔为2

3.负索引

print(x[-1])        # 最后一行
print(x[:, -2:])    # 每行最后两列

4.使用 ... (Ellipsis)

当维度很多时可简化操作。

x = torch.arange(2*3*4).reshape(2, 3, 4)

# 等价于 x[0, :, 2]
print(x[0, ..., 2])

5.Noneunsqueeze 增加维度

x = torch.tensor([1, 2, 3])

# 增加维度(等价于 unsqueeze)
print(x[None, :].shape)     # torch.Size([1, 3])
print(x[:, None].shape)     # torch.Size([3, 1])

6. 布尔索引(Boolean Indexing)

x = torch.tensor([10, 20, 30, 40])

mask = x > 25
print(mask)         # tensor([False, False,  True,  True])
print(x[mask])      # tensor([30, 40])

7. 花式索引(Fancy Indexing)

使用索引列表访问多个非连续位置。

x = torch.tensor([10, 20, 30, 40, 50])

idx = torch.tensor([0, 2, 4])
print(x[idx])       # tensor([10, 30, 50])

二维花式索引:

x = torch.arange(1, 10).reshape(3, 3)
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])

rows = torch.tensor([0, 1, 2])
cols = torch.tensor([2, 1, 0])
print(x[rows, cols])  # [3, 5, 7]

8. 条件赋值 / where

x = torch.tensor([1, 2, 3, 4, 5])
x[x > 3] = 100
print(x)            # tensor([  1,   2,   3, 100, 100])

# 条件选择
a = torch.tensor([1, 2, 3])
b = torch.tensor([10, 20, 30])
cond = torch.tensor([True, False, True])

print(torch.where(cond, a, b))  # -> [1, 20, 3]

9. 高维张量索引技巧

x = torch.arange(2*3*4).reshape(2, 3, 4)

# 提取第1个 batch 所有通道第2列
print(x[0, :, 2])    # shape: (3,)

10. 实例:图像张量裁剪(HWC)

img = torch.rand((3, 256, 256))  # C, H, W 格式

# 裁剪中心区域
crop = img[:, 100:200, 100:200]  # shape (3, 100, 100)

11. 总结图解(结构化索引方式)

张量索引方式:
├── 基础索引(x[i], x[i,j])
├── 切片(x[start:end], x[:, idx])
├── 高维省略(x[..., -1])
├── 增维/降维(x[None, :], x.squeeze())
├── 布尔索引(x[x>val])
├── 花式索引(x[[0, 2, 4]])
├── 条件赋值(x[x > a] = b)
└── torch.where(cond, a, b)

高级应用


1. 高级花式索引(Advanced Fancy Indexing)

基本复习:

花式索引是用整张或部分张量作为索引,获取非连续元素。进阶里,张量的形状组合、广播规则非常重要。

代码示例:

import torch

x = torch.arange(27).reshape(3, 3, 3)
# x shape = (3, 3, 3)

# 目标:同时选取不同 batch 不同通道的元素
idx_batch = torch.tensor([0, 1, 2])   # 每个 batch 索引
idx_channel = torch.tensor([2, 1, 0]) # 每个对应通道索引
idx_row = torch.tensor([0, 1, 2])     # 对应行索引

# 三个索引张量自动广播,选出:
# x[0, 2, 0], x[1, 1, 1], x[2, 0, 2]
result = x[idx_batch, idx_channel, idx_row]

print(result)  # tensor([ 6, 13, 24])
  • 关键是各个索引张量形状要匹配或可广播
  • 返回值的形状取决于索引张量的形状。

2. 坐标映射索引(Indexing with Coordinate Tensors)

常用在点云、图像坐标映射,手工给定索引位置批量取值。

代码示例:

x = torch.arange(16).reshape(4, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])

# 给定坐标点
coords = torch.tensor([[0, 1], [2, 3], [3, 0]])  # 三个点的坐标

rows = coords[:, 0]
cols = coords[:, 1]

vals = x[rows, cols]
print(vals)  # tensor([ 1, 11, 12])

torch.gather — 按索引沿指定维度收集数据

x = torch.arange(12).reshape(3, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

indices = torch.tensor([[0, 3], [2, 1], [1, 0]])
result = torch.gather(x, dim=1, index=indices)
print(result)
# tensor([[ 0,  3],
#         [ 6,  5],
#         [ 9,  8]])
  • torch.gather 需要索引张量与输入同形状,但索引值表示该维度的选取位置。

3. 高维图像张量处理技巧

假设图像张量格式为 (Batch, Channels, Height, Width),称为 BCHW。

常用操作示例:

(a) 批量裁剪 (Crop)
img = torch.randn(5, 3, 256, 256)  # 5张RGB图像

# 取中心128x128块
h_start = (256 - 128) // 2
w_start = (256 - 128) // 2

crop = img[:, :, h_start:h_start+128, w_start:w_start+128]  # shape (5, 3, 128, 128)
(b) 改变通道顺序
# BCHW -> BHWC
img_bhwc = img.permute(0, 2, 3, 1)
print(img_bhwc.shape)  # (5, 256, 256, 3)
© 按坐标索引批量像素点
batch_size = 2
img = torch.arange(batch_size*3*4*4).reshape(batch_size, 3, 4, 4)

# 取每张图(0,1)通道,指定像素点坐标
coords = torch.tensor([[1, 2], [3, 0]])  # (batch_size, 2) 像素坐标 (H, W)

batch_indices = torch.arange(batch_size)
channels = torch.tensor([0, 1])  # 不同图不同通道

pixels = img[batch_indices, channels, coords[:, 0], coords[:, 1]]
print(pixels)

总结:

技巧类别 适用场景 关键函数/概念
高级花式索引 多维非连续索引,索引张量广播 多张量索引广播
坐标映射索引 点云坐标、图像点批量索引 torch.gather, 坐标张量索引
高维图像张量处理 批量裁剪、通道转换、批量像素选取 permutereshape、多维切片

4.综合示例

下面以一个综合示例代码,涵盖 高级花式索引坐标映射索引,以及 高维图像张量处理,注释详尽,方便大家理解和直接跑起来。

import torch

def advanced_fancy_indexing():
    print("=== 高级花式索引示例 ===")
    x = torch.arange(27).reshape(3, 3, 3)
    idx_batch = torch.tensor([0, 1, 2])
    idx_channel = torch.tensor([2, 1, 0])
    idx_row = torch.tensor([0, 1, 2])
    # 选出 x[0,2,0], x[1,1,1], x[2,0,2]
    result = x[idx_batch, idx_channel, idx_row]
    print(result)  # tensor([ 6, 13, 24])
    print()

def coordinate_mapping_indexing():
    print("=== 坐标映射索引示例 ===")
    x = torch.arange(16).reshape(4, 4)
    coords = torch.tensor([[0, 1], [2, 3], [3, 0]])  # 3个坐标点
    rows = coords[:, 0]
    cols = coords[:, 1]
    vals = x[rows, cols]
    print(f"从坐标 {coords.tolist()} 取值: {vals.tolist()}")

    # torch.gather示例
    x2 = torch.arange(12).reshape(3, 4)
    indices = torch.tensor([[0, 3], [2, 1], [1, 0]])
    gathered = torch.gather(x2, dim=1, index=indices)
    print(f"torch.gather 结果:\n{gathered}")
    print()

def high_dim_image_tensor_processing():
    print("=== 高维图像张量处理示例 ===")
    # 生成一个 5张RGB图像 BCHW 格式
    img = torch.randn(5, 3, 256, 256)

    # 裁剪中心128x128
    h_start = (256 - 128) // 2
    w_start = (256 - 128) // 2
    crop = img[:, :, h_start:h_start+128, w_start:w_start+128]
    print(f"裁剪后的形状: {crop.shape}")

    # 通道顺序变换 BCHW -> BHWC
    img_bhwc = img.permute(0, 2, 3, 1)
    print(f"通道转换后形状: {img_bhwc.shape}")

    # 批量取像素点
    batch_size = 2
    img_small = torch.arange(batch_size*3*4*4).reshape(batch_size, 3, 4, 4)

    coords = torch.tensor([[1, 2], [3, 0]])  # 每张图像的像素坐标 (H, W)
    batch_indices = torch.arange(batch_size)
    channels = torch.tensor([0, 1])  # 两张图不同通道

    pixels = img_small[batch_indices, channels, coords[:, 0], coords[:, 1]]
    print(f"批量像素值: {pixels.tolist()}")

if __name__ == "__main__":
    advanced_fancy_indexing()
    coordinate_mapping_indexing()
    high_dim_image_tensor_processing()

代码说明

  • advanced_fancy_indexing()
    演示多张量广播索引从三维张量中选取不规则元素。

  • coordinate_mapping_indexing()
    演示给定坐标点批量取值 + 用 torch.gather 沿某维度收集。

  • high_dim_image_tensor_processing()
    展示了高维图像张量裁剪、通道排列变换和批量像素点采样。



网站公告

今日签到

点亮在社区的每一天
去签到