在 PyTorch 中,flatten()
函数常用于将张量(tensor)展平成一维或多维结构,尤其在构建神经网络(如 CNN)时,从卷积层输出进入全连接层前经常使用它。
一、基本语法
torch.flatten(input, start_dim=0, end_dim=-1)
参数说明:
参数 | 说明 |
---|---|
input |
输入张量 |
start_dim |
开始展平的维度(包含该维) |
end_dim |
结束展平的维度(包含该维) |
展平操作会把
start_dim
到end_dim
之间的维度合并成一维。
二、常见示例
示例 1:基本使用
import torch
x = torch.tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]]) # shape = (2, 2, 2)
out = torch.flatten(x)
print(out)
print(out.shape) # torch.Size([8])
等价于 x.view(-1)
,即将所有维度展平成一维。
示例 2:保留前维度(常见于 CNN)
x = torch.randn(10, 3, 32, 32) # 10张图片,3通道,32x32大小
out = torch.flatten(x, start_dim=1)
print(out.shape) # torch.Size([10, 3072])
解释:
- 展平从第 1 维开始(channel, height, width)→ 展平成一个维度
- 第 0 维(batch size)保留,适合连接到
nn.Linear
层
示例 3:多维展开(指定 end_dim)
x = torch.randn(2, 3, 4, 5) # shape = (2, 3, 4, 5)
out = torch.flatten(x, start_dim=1, end_dim=2)
print(out.shape) # torch.Size([2, 12, 5]) -> (3*4 = 12)
三、与 .view()
的区别
函数 | 说明 |
---|---|
view() |
更底层、需要张量是连续的,手动指定形状 |
flatten() |
更高层、更安全、自动处理维度合并,常用于模型构建中 |
四、常见用法:在模型中使用
1、示例1
import torch.nn as nn
class MyCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(16, 10)
def forward(self, x):
x = self.conv(x)
x = self.pool(x) # shape: (N, 16, 1, 1)
x = torch.flatten(x, 1) # shape: (N, 16)
x = self.fc(x)
return x
2、示例2
下面使用了 torch.flatten()
将卷积层的输出展平,并连接到全连接层。这个结构常见于 CNN 图像分类模型。
使用 flatten()
的 CNN 训练流程(以 CIFAR-10 为例):
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# ==== 1. 定义 CNN 模型,使用 flatten() ====
class FlattenCNN(nn.Module):
def __init__(self):
super(FlattenCNN, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1), # 输入: [B, 3, 32, 32]
nn.ReLU(),
nn.MaxPool2d(2), # 输出: [B, 16, 16, 16]
nn.Conv2d(16, 32, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2) # 输出: [B, 32, 8, 8]
)
self.fc = nn.Sequential(
nn.Linear(32 * 8 * 8, 128),
nn.ReLU(),
nn.Linear(128, 10) # CIFAR-10 共 10 类
)
def forward(self, x):
x = self.conv(x)
x = torch.flatten(x, 1) # 👈 仅展平通道和空间维度,保留 batch
x = self.fc(x)
return x
# ==== 2. 准备数据 ====
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# ==== 3. 模型训练设置 ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlattenCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# ==== 4. 训练过程 ====
def train(model, loader, epochs):
model.train()
for epoch in range(epochs):
total_loss = 0.0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(loader)
print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")
# ==== 5. 开始训练 ====
train(model, train_loader, epochs=5)
重点说明
使用 torch.flatten(x, 1)
的原因:
- 只展平通道、高、宽三维(保留 batch size)
- 替代
x.view(x.size(0), -1)
更安全,避免非连续张量报错 - 推荐在模型中构建更加模块化、清晰
五、三种张量展平方式:flatten()
、view()
和 reshape()
的对比
下面从功能差异、使用限制和**性能对比(benchmark)**进行三者的比较。
1、三者功能对比
函数 | 特点说明 |
---|---|
flatten() |
高级 API,自动处理维度合并,不要求张量连续。推荐模型中使用。 |
view() |
底层操作,速度快,但要求张量是连续(tensor.is_contiguous() 为 True ) |
reshape() |
更灵活,如果张量不连续,会自动复制为连续版本。性能略慢但更安全 |
2、代码功能对比
x = torch.randn(32, 3, 64, 64) # batch of images
# flatten
f1 = torch.flatten(x, 1)
# view
f2 = x.view(32, -1)
# reshape
f3 = x.reshape(32, -1)
print(f1.shape, f2.shape, f3.shape)
输出一致:torch.Size([32, 12288])
3、非连续张量对比(view 会报错)
x = torch.randn(2, 3, 4)
y = x.permute(0, 2, 1) # 非连续张量
try:
y.view(-1) # 会报错
except RuntimeError as e:
print("view error:", e)
print("reshape:", y.reshape(-1).shape) # reshape 正常
print("flatten:", torch.flatten(y).shape) # flatten 正常
4、性能测试(benchmark)
import torch
import time
x = torch.randn(1024, 512, 28, 28)
# 保证是连续的
x_contig = x.contiguous()
N = 1000
def benchmark(op, name):
torch.cuda.synchronize()
start = time.time()
for _ in range(N):
_ = op(x_contig)
torch.cuda.synchronize()
end = time.time()
print(f"{name}: {(end - start)*1000:.2f} ms")
benchmark(lambda x: torch.flatten(x, 1), "flatten()")
benchmark(lambda x: x.view(x.size(0), -1), "view()")
benchmark(lambda x: x.reshape(x.size(0), -1), "reshape()")
示例结果(A100 GPU):
flatten(): 58.12 ms
view(): 41.76 ms
reshape(): 47.32 ms
总结:
view()
最快,但要求张量连续;flatten()
最安全但稍慢;reshape()
是折中方案。
5、 建议总结
场景 | 推荐方式 | 原因 |
---|---|---|
模型中展平 CNN 输出 | flatten() |
简洁、安全,尤其在复杂网络中 |
确保连续张量、追求速度 | view() |
性能最佳 |
张量可能非连续 | reshape() |
自动处理不连续情况,代码更鲁棒 |
六、小结
用法 | 效果 |
---|---|
torch.flatten(x) |
将所有维展平成一维 |
torch.flatten(x, 1) |
保留 batch 维,常用于 CNN |
torch.flatten(x, 1, 2) |
展平指定维度区间 |