Pytorch下张量的形状操作(详细)

发布于:2024-04-26 ⋅ 阅读:(27) ⋅ 点赞:(0)

目录

一、基本操作函数

二、分类:维度改变,张量变形,维度重排

2.1维度改变

2.2张量变形

2.3维度重排

三、实例


一、基本操作函数

在PyTorch中,对张量的形状进行操作是常见的需求,因为它允许我们重新组织、选择和操纵数据,以适应各种模型和函数的需求。以下是一些基本的形状操作函数:

  1. view(): 该方法用于重塑张量。它返回一个新的张量,其数据与原张量相同,但形状不同。你需要保证新形状与原始形状的总元素数相同。

  2. reshape(): 与view()类似,reshape()也可以改变张量的形状。不同之处在于,reshape()可以处理不连续的张量,而view()要求内存中的数据必须是连续的。

  3. squeeze(): 用于去除张量形状中所有的单维度条目,例如将形状为(1, A, 1, B)的张量压缩成(A, B)

  4. unsqueeze(): 在指定位置增加一个尺寸为1的新维度,例如将形状为(A, B)的张量扩展为(1, A, B)(A, 1, B)等。

  5. permute(): 用于重新排列张量的维度。例如,可以将一个形状为(A, B, C)的张量重排为(B, C, A)

  6. transpose(): 用于交换张量的两个维度。通常用于二维张量,但也可以用于多维。

  7. contiguous(): 使张量在内存中连续存储,通常在调用view()之前使用,如果张量在内存中不连续。

  8. size(): 返回张量的形状。

  9. dim(): 返回张量的维度数。

二、分类:维度改变,张量变形,维度重排

2.1维度改变

维度改变指的是增加或减少张量的维度数目。常见的操作有:

  • unsqueeze():在指定的维度处增加一个尺寸为1的新维度,通常用于为已有数据添加批处理维度或其他需要的单独维度。
  • squeeze():去除张量中所有长度为1的维度,或者在指定位置去除单独的长度为1的维度。这常用于去除多余的维度,简化数据结构。

2.2张量变形

张量变形是调整张量内部元素的排列顺序但保持总元素数量不变。这类操作包括:

  • view():重塑张量到一个指定的形状。此操作要求原始数据在内存中连续,如果不连续,通常需要先调用contiguous()
  • reshape():功能与view()相似,但可以自动处理数据的连续性问题。它在不改变数据的总元素数的情况下更改形状。

2.3维度重排

维度重排涉及调整张量的维度顺序,这在处理不同数据格式时特别有用,比如从NCHW转换到NHWC。相关操作包括:

  • transpose():用于交换张量中的两个维度。它特别常用于处理2D数据,如在矩阵转置中。
  • permute():更一般化的维度交换操作,可以一次性重新排序多个维度。这使得它非常灵活,适用于复杂的多维数据重排需求。

三、实例

这里将通过一个简单的Python例子来展示如何在PyTorch中使用上述的张量操作函数。我们将创建一个张量,然后对其进行维度改变、张量变形和维度重排的操作。

假设我们正在处理图像数据,我们有一个表示多个RGB图像的4维张量,形状为(batch_size, channels, height, width)。我们将执行以下步骤:

  1. 增加一个维度来表示时间序列(例如视频帧)。
  2. 将张量展平,以便可以将其用于全连接层。
  3. 将通道置于最后(从NCHW到NHWC格式)。

代码:

import torch

# 创建一个初始张量,形状为 (batch_size, channels, height, width)
batch_size, channels, height, width = 3, 3, 240, 320
x = torch.randn(batch_size, channels, height, width)

# 增加一个时间维度,假设每个批次有5帧
time_steps = 5
x = x.unsqueeze(1)  # 在第二个维度处增加
x = x.expand(-1, time_steps, -1, -1, -1)  # 将新维度扩展到5

# 输出增加时间维度后的张量形状
print("Shape after adding time dimension:", x.shape)

# 交换维度,将通道从第三位置移到最后
x = x.permute(0, 1, 3, 4, 2)  # 结果的形状将是(batch_size, time_steps, height, width, channels)
print("Shape after permuting:", x.shape)

# 展平张量,除批次和时间维度外
x = x.reshape(batch_size, time_steps, -1)  # -1会自动计算需要的大小
print("Shape after flattening:", x.shape)

说明:

  • 首先,我们创建了一个随机的张量x,代表了一个批次中的多个RGB图像。
  • 接着,我们在unsqueeze()中增加了一个时间维度,并用expand()方法填充这个维度,模拟一个时间序列数据。
  • 然后,我们用reshape()方法将除时间和批次外的其他维度合并,为后续的神经网络层准备。
  • 最后,我们使用permute()重新排列维度,将通道放到最后,这对某些图像处理库更为友好。

结果:

  • 增加时间维度后:形状是(3, 5, 3, 240, 320),表示有3个批次,每批有5帧,每帧3个通道,每通道240x320像素。
  • 交换维度后:形状是(3, 5, 240, 320, 3),其中通道被移到了最后。
  • 展平操作后:形状是(3, 5, 230400),表示每批每帧的所有像素值和通道都被展平。