pytorch nn.Unflatten 和 nn.Flatten模块介绍

发布于:2025-02-11 ⋅ 阅读:(167) ⋅ 点赞:(0)

nn.Flatten 和 nn.Unflatten 是 PyTorch 中用于调整张量形状的模块。它们提供了对多维张量的简单变换,常用于神经网络模型的层之间的数据调整。


1. nn.Flatten

功能:

  • 将输入张量展平为二维张量,通常用于将卷积层的输出展平成全连接层的输入。
  • 它会将张量的指定维度范围压缩为单个维度。

构造参数:

  • start_dim: 展平的起始维度(默认值为 1)。
  • end_dim: 展平的结束维度(默认值为 -1)。

用法:

import torch
from torch import nn

# 输入张量: [batch_size, channels, height, width]
x = torch.randn(4, 3, 32, 32)

# 展平操作
flatten = nn.Flatten(start_dim=1)  # 从维度1到最后展平
y = flatten(x)

print(y.shape)  # 输出: [4, 3072] (3*32*32 被展平)

适用场景:

  • 通常用于从卷积层(或其他多维特征)到全连接层的过渡。
  • 例如:[batch_size, channels, height, width] -> [batch_size, features]

2. nn.Unflatten

功能:

  • 将展平的张量还原为多维张量。
  • 它通过指定目标维度和形状信息,反向操作 nn.Flatten

构造参数:

  • dim: 需要展开的维度。
  • unflattened_size: 展开的形状(tuple 类型)。

用法:

import torch
from torch import nn

# 输入张量: [batch_size, features]
x = torch.randn(4, 3072)

# 还原操作
unflatten = nn.Unflatten(dim=1, unflattened_size=(3, 32, 32))
y = unflatten(x)

print(y.shape)  # 输出: [4, 3, 32, 32]

适用场景:

  • 通常用于从全连接层(或展平特征)还原到卷积层或其他多维表示。
  • 例如:[batch_size, features] -> [batch_size, channels, height, width]

对比

特性 nn.Flatten nn.Unflatten
主要操作 将多个维度压缩为一个维度 将一个维度展开为多个维度
输入 多维张量 展平的张量
输出 二维张量 恢复为多维张量
常用场景 用于连接卷积层和全连接层 用于从展平的特征恢复到多维结构
参数控制 指定展平的起始和结束维度范围 指定需要展开的维度和目标形状

实际应用示例

结合使用 Flatten 和 Unflatten:

import torch
from torch import nn

# 初始化 Flatten 和 Unflatten
flatten = nn.Flatten(start_dim=1)
unflatten = nn.Unflatten(dim=1, unflattened_size=(3, 32, 32))

# 模拟数据
x = torch.randn(4, 3, 32, 32)  # [batch_size, channels, height, width]

# 展平
flat_x = flatten(x)
print(flat_x.shape)  # 输出: [4, 3072]

# 恢复
unflat_x = unflatten(flat_x)
print(unflat_x.shape)  # 输出: [4, 3, 32, 32]

这两个模块通过简单的接口提供了灵活的形状调整功能,是构建神经网络过程中不可或缺的工具。


网站公告

今日签到

点亮在社区的每一天
去签到