ResNet 网络教学文章
一、背景介绍
在深度学习发展早期,研究者发现:更深的神经网络往往拥有更强的拟合能力。但在实际实验中,当网络层数不断加深时,效果却出现了瓶颈甚至退化现象:
梯度消失/爆炸:反向传播时梯度越来越小或越来越大,导致训练困难。
退化问题:在相同训练数据和优化策略下,增加网络深度反而导致训练误差升高。
针对这一问题,微软研究院的 Kaiming He 等人提出了 残差网络(Residual Network,ResNet),并在 2015 年的 ImageNet 比赛中夺得冠军。ResNet 的核心思想是 残差学习(Residual Learning),通过引入 快捷连接(skip connection) 来缓解梯度消失和退化问题,从而让深层网络训练成为可能。
二、核心思想:残差学习
传统的深层神经网络可以表示为:
其中,H(x)H(x) 是期望学习的映射,F(x)F(x) 是堆叠层实现的函数
ResNet 的创新点是:与其直接学习 H(x)H(x),不如让网络去学习 残差函数:
也就是说,网络学习的不是直接的输出,而是输入和输出的差异(残差)。
这种设计有两个好处:
如果最优映射接近恒等函数,学习残差更容易。
快捷连接保证了信息和梯度可以直接跨层传播,避免梯度消失。
三、ResNet 的基本结构
1. 残差模块(Residual Block)
残差模块是 ResNet 的核心。其基本结构为:
两层卷积(Conv-BN-ReLU),负责提取特征。
一条 shortcut connection,将输入直接加到输出上。
公式:
结构图示:
输入 x
│
[卷积-BN-ReLU]
│
[卷积-BN]
│
+──────→ 输出 y
│
输入 x
2. 瓶颈结构(Bottleneck Block)
在更深层的 ResNet(如 ResNet-50/101/152)中,引入了 瓶颈结构 来减少计算量:
先用 1×1 卷积降维。
再用 3×3 卷积进行特征提取。
最后用 1×1 卷积升维。
这样既能减少参数量,又能保持特征表达能力。
四、经典 ResNet 网络结构
常见的 ResNet 网络深度有 18、34、50、101、152 层。其区别主要在于残差模块的堆叠数量和是否使用瓶颈结构。
以 ResNet-18 为例,整体结构:
Conv1:7×7 卷积 + 最大池化
Conv2_x:2 个残差模块
Conv3_x:2 个残差模块
Conv4_x:2 个残差模块
Conv5_x:2 个残差模块
全局平均池化 + 全连接层
ResNet-50、101、152 则采用瓶颈残差模块,并堆叠更多层数。
五、PyTorch 实现示例
下面给出一个简化版的 ResNet 实现代码(以 ResNet-18 为例):
import torch
import torch.nn as nn
import torch.nn.functional as F
# 基本残差模块
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.downsample = downsample # 用于通道数不一致时的调整
def forward(self, x):
identity = x
if self.downsample is not None:
identity = self.downsample(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += identity
out = self.relu(out)
return out
# ResNet 主体
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000):
super(ResNet, self).__init__()
self.in_channels = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, blocks, stride=1):
downsample = None
if stride != 1 or self.in_channels != out_channels * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.in_channels, out_channels * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels * block.expansion),
)
layers = []
layers.append(block(self.in_channels, out_channels, stride, downsample))
self.in_channels = out_channels * block.expansion
for _ in range(1, blocks):
layers.append(block(self.in_channels, out_channels))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def ResNet18(num_classes=1000):
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
# 测试
if __name__ == "__main__":
model = ResNet18(num_classes=10)
print(model)
x = torch.randn(1, 3, 224, 224)
y = model(x)
print(y.shape)
六、ResNet 的应用
ResNet 已成为计算机视觉任务中的经典网络结构,被广泛应用于:
图像分类(ImageNet、CIFAR-10)
目标检测(Faster R-CNN, YOLO 的 backbone)
图像分割(Mask R-CNN, U-Net 的变体)
医学影像分析(CT/MRI 分类、分割)
语音和自然语言处理(Transformer 前的深层特征提取)
七、总结
ResNet 的核心思想:通过残差学习和快捷连接解决深层网络的退化问题。
残差模块:基础块(BasicBlock)和瓶颈块(Bottleneck)。
效果:ResNet 让网络层数从几十层突破到上百层,极大推动了深度学习的发展。
应用:几乎所有计算机视觉主流模型都与 ResNet 有关。