Pytorch中view函数详解和工程实战示例

发布于:2025-06-13 ⋅ 阅读:(20) ⋅ 点赞:(0)

在 PyTorch 中,view() 是一个非常常用的张量(Tensor)操作函数,用于 改变张量的形状(shape),但 不会改变其数据内容。你可以把它看作是 PyTorch 中的类似 NumPy 中 reshape() 的方法。


一、基本语法

tensor.view(shape)
  • shape:目标张量的形状,可以是多个整数参数,也可以是一个 tuple。
  • 其中某个维度可以设为 -1,PyTorch 会根据总元素数自动推断这个维度的大小。

二、注意事项

  1. view() 要求 原张量是连续的内存(contiguous),否则需先 .contiguous()
  2. view() 改变的是张量的“视图”,不拷贝数据,效率高;
  3. view() 返回的是一个新的张量,原张量不变;
  4. -1 只能出现一次,用于自动计算维度。

三、代码示例

示例 1:基本使用

import torch

a = torch.arange(12)         # 创建一个 1D 张量 [0, 1, ..., 11]
print(a.shape)               # torch.Size([12])

b = a.view(3, 4)             # 变成 3x4 的二维张量
print(b)

输出:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

示例 2:使用 -1 自动推断维度

a = torch.arange(12)
b = a.view(3, -1)            # 自动计算出第二个维度为 4
print(b.shape)               # torch.Size([3, 4])

示例 3:使用 view() 时不连续的内存

a = torch.randn(4, 3)
b = a.transpose(0, 1)        # 维度交换后,b 不是连续内存
try:
    b.view(-1)               # 报错
except RuntimeError as e:
    print("错误:", e)

# 正确做法:先调用 contiguous()
b = b.contiguous().view(-1)

示例 4:结合 batch 维度 reshape 图像

img = torch.randn(16, 3, 32, 32)     # 假设是一个 batch 的图像 (N=16, C=3, H=32, W=32)
flat = img.view(16, -1)              # 展平成向量,每个图像是 3072 维
print(flat.shape)                    # torch.Size([16, 3072])

四、view() vs reshape()

a.view(3, 4)
a.reshape(3, 4)
  • 相同点:功能类似,都能改变形状;

  • 不同点:

    • view() 要求连续内存;
    • reshape() 更灵活,自动处理非连续情况,内部可能会拷贝数据。

推荐:如果你确信内存连续,view() 更高效;否则用 reshape() 更稳妥


五、工程实战应用示例

在使用 CNN(卷积神经网络)或 RNN(循环神经网络)时,view() 函数常被用于 调整张量的形状 以满足网络结构输入/输出的要求。以下是典型场景中的用法分析和代码示例:

1、CNN 中 view() 的常见用途

场景:Flatten 卷积层输出,接入全连接层(Linear)

卷积层输出通常是一个 4D 张量 (batch_size, channels, height, width),在接入 nn.Linear 全连接层前需要展平为 2D,即 (batch_size, features)

示例:
import torch
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, 3),  # 输入:1x28x28,输出:16x26x26
            nn.ReLU(),
            nn.MaxPool2d(2)       # 输出:16x13x13
        )
        self.fc = nn.Linear(16 * 13 * 13, 10)

    def forward(self, x):
        x = self.conv(x)             # shape: (batch, 16, 13, 13)
        x = x.view(x.size(0), -1)    # shape: (batch, 2704)
        x = self.fc(x)
        return x

x = torch.randn(8, 1, 28, 28)  # batch of 8 images
model = CNN()
output = model(x)
print(output.shape)  # torch.Size([8, 10])

x.view(x.size(0), -1) 保证了展平时 batch size 不变,特征自动推断。


2、RNN 中 view() 的常见用途

场景 1:将 RNN 的输出展平后送入全连接层

RNN 输出形状通常是 (seq_len, batch_size, hidden_size),有时需要 reshape 为 (batch_size * seq_len, hidden_size) 再处理。

示例:
class RNNClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(input_size=10, hidden_size=20)
        self.fc = nn.Linear(20, 5)

    def forward(self, x):
        out, _ = self.rnn(x)           # out: (seq_len, batch_size, 20)
        out = out.view(-1, 20)         # flatten 所有时间步,shape: (seq_len * batch_size, 20)
        out = self.fc(out)             # shape: (seq_len * batch_size, 5)
        return out

x = torch.randn(15, 4, 10)  # seq_len=15, batch_size=4, input_size=10
model = RNNClassifier()
output = model(x)
print(output.shape)  # torch.Size([60, 5])

场景 2:将嵌套序列 batch + time 展平为单个输入批

有时会将输入 (batch, seq_len, input_dim) reshape 为 (batch * seq_len, input_dim) 以方便线性层处理:

x = torch.randn(32, 10, 100)           # batch_size=32, seq_len=10, input_dim=100
x = x.view(-1, 100)                    # -> (320, 100)

3、CNN + RNN 混合模型中的 view() 用法

比如:图像经过 CNN 得到特征,再按时间顺序输入 RNN。这里就需要在 CNN 输出后 reshape 成 RNN 接受的格式 (seq_len, batch, input_size)

示例:
class CNN_RNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3),  # 1x28x28 -> 8x26x26
            nn.ReLU(),
            nn.MaxPool2d(2)                 # -> 8x13x13
        )
        self.rnn = nn.GRU(13 * 13, 64)
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        batch_size, time_steps, C, H, W = x.shape  # e.g., (16, 5, 1, 28, 28)
        x = x.view(batch_size * time_steps, C, H, W)
        x = self.cnn(x)                            # -> (batch*time, 8, 13, 13)
        x = x.view(batch_size, time_steps, -1)     # -> (batch, time, 1352)
        x = x.permute(1, 0, 2)                     # -> (time, batch, features)
        out, _ = self.rnn(x)
        out = self.fc(out[-1])
        return out

x = torch.randn(16, 5, 1, 28, 28)  # batch=16, time=5
model = CNN_RNN()
output = model(x)
print(output.shape)  # torch.Size([16, 10])

总结:常见 view() 用法对照表

用法场景 示例代码 解释
Flatten CNN 输出 x.view(x.size(0), -1) 展平成 (B, C*H*W)
展开 RNN 所有时间步输出 x.view(-1, hidden_size) -> (T*B, H)
输入 RNN 前展开为 (seq, B, dim) x.permute(1, 0, 2) 时间维放前面
多帧图像展平成一个 batch x.view(B*T, C, H, W) 常用于视频输入 CNN

总结

特性 view()
功能 改变张量形状
是否拷贝数据 否(更高效)
内存要求 连续(需 .contiguous())
支持 -1 吗 支持自动推断