"""
知识点回顾:
- resnet结构解析
- CBAM放置位置的思考
- 针对预训练模型的训练策略
- 差异化学习率
- 三阶段微调
作业:
- 好好理解下resnet18的模型结构
- 尝试对vgg16+cbam进行微调策略
"""
import copy
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms
# 定义CBAM模块
class ChannelAttention(nn.Module):
    """CBAM channel attention: squeeze spatial dims with avg- and max-pooling,
    score channels through a shared bottleneck MLP, and gate with a sigmoid."""

    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared 1x1-conv bottleneck applied to both pooled descriptors.
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Sum the MLP responses of both global descriptors, then squash to (0, 1).
        descriptors = (self.avg_pool(x), self.max_pool(x))
        attention = sum(self.fc(d) for d in descriptors)
        return self.sigmoid(attention)
class SpatialAttention(nn.Module):
    """CBAM spatial attention: collapse the channel dim with mean and max,
    then score each spatial location with a single conv + sigmoid."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # Padding of kernel_size // 2 keeps the output map the same size as the input.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat((mean_map, max_map), dim=1)
        return self.sigmoid(self.conv(descriptor))
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention followed by
    spatial attention, each applied multiplicatively to the feature map."""

    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_att = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        refined = self.channel_att(x) * x
        return self.spatial_att(refined) * refined
# 修改VGG16模型,插入CBAM模块
class VGG16_CBAM(nn.Module):
def __init__(self, num_classes=1000, pretrained=True):
super(VGG16_CBAM, self).__init__()
# 加载预训练的VGG16
vgg16 = models.vgg16(pretrained=pretrained)
self.features = vgg16.features
# 在每个MaxPool2d后插入CBAM模块
new_features = []
cbam_idx = 0
for module in self.features:
new_features.append(module)
if isinstance(module, nn.MaxPool2d):
# 不在第一个MaxPool后添加CBAM
if cbam_idx > 0:
in_channels = list(module.parameters())[0].shape[1]
new_features.append(CBAM(in_channels))
cbam_idx += 1
self.features = nn.Sequential(*new_features)
self.avgpool = vgg16.avgpool
self.classifier = vgg16.classifier
# 修改最后一层以适应指定的类别数
if num_classes != 1000:
self.classifier[-1] = nn.Linear(self.classifier[-1].in_features, num_classes)
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
# Three-phase fine-tuning strategy
def train_model_three_phase(model, dataloaders, criterion, device, num_epochs=25):
    """Fine-tune `model` in three phases: (1) classifier head only,
    (2) classifier plus the tail of the feature extractor with discriminative
    (per-group) learning rates, (3) the whole network at a very small LR.

    NOTE(review): `num_epochs` is accepted but unused — each phase uses its
    own hard-coded epoch count in the train_one_phase calls below.
    """
    # Phase 1: freeze everything, then train only the classifier head.
    print("第一阶段:只训练分类器")
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze the classifier parameters.
    for param in model.classifier.parameters():
        param.requires_grad = True
    optimizer = optim.SGD(model.classifier.parameters(), lr=0.001, momentum=0.9)
    model = train_one_phase(model, dataloaders, criterion, optimizer, device, num_epochs=5)
    # Phase 2: additionally unfreeze the tail of `features`, training it with
    # a smaller learning rate than the classifier.
    print("\n第二阶段:解冻部分层并使用差异化学习率")
    # NOTE(review): index 24 presumably marks the start of the last feature
    # blocks, but after CBAM modules are inserted the Sequential indices no
    # longer match vanilla VGG16 — confirm 24 still targets the intended
    # layers for this modified architecture.
    for i in range(24, len(model.features)):
        for param in model.features[i].parameters():
            param.requires_grad = True
    # Build optimizer parameter groups with per-group learning rates.
    params_to_update = []
    # Feature extractor: low LR to gently adapt the pretrained filters.
    params_to_update.append({
        'params': [param for param in model.features.parameters() if param.requires_grad],
        'lr': 0.0001
    })
    # Classifier: higher LR since it is task-specific.
    params_to_update.append({
        'params': model.classifier.parameters(),
        'lr': 0.001
    })
    optimizer = optim.SGD(params_to_update, momentum=0.9)
    model = train_one_phase(model, dataloaders, criterion, optimizer, device, num_epochs=10)
    # Phase 3: unfreeze all layers and fine-tune end-to-end at a tiny LR.
    print("\n第三阶段:微调整个网络")
    for param in model.parameters():
        param.requires_grad = True
    optimizer = optim.SGD(model.parameters(), lr=0.00001, momentum=0.9)
    model = train_one_phase(model, dataloaders, criterion, optimizer, device, num_epochs=10)
    return model
# 辅助函数:执行一个阶段的训练
def train_one_phase(model, dataloaders, criterion, optimizer, device, num_epochs=5):
    """Run one training phase (train + val each epoch) and return the model
    restored to the weights that achieved the best validation accuracy.

    dataloaders: dict with 'train' and 'val' DataLoader entries.
    """
    # Snapshot the starting weights so there is always a "best so far".
    top_wts = copy.deepcopy(model.state_dict())
    top_acc = 0.0
    model.to(device)
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-' * 10)
        for split in ('train', 'val'):
            training = split == 'train'
            # train(True) == .train(), train(False) == .eval()
            model.train(training)
            loss_sum = 0.0
            hit_count = 0
            for batch_x, batch_y in dataloaders[split]:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                optimizer.zero_grad()
                # Gradients only during the training split.
                with torch.set_grad_enabled(training):
                    logits = model(batch_x)
                    _, guesses = torch.max(logits, 1)
                    batch_loss = criterion(logits, batch_y)
                    if training:
                        batch_loss.backward()
                        optimizer.step()
                # Weight the mean batch loss by batch size for a dataset average.
                loss_sum += batch_loss.item() * batch_x.size(0)
                hit_count += torch.sum(guesses == batch_y.data)
            n_samples = len(dataloaders[split].dataset)
            epoch_loss = loss_sum / n_samples
            epoch_acc = hit_count.double() / n_samples
            print(f'{split} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # Track the best validation accuracy seen so far.
            if split == 'val' and epoch_acc > top_acc:
                top_acc = epoch_acc
                top_wts = copy.deepcopy(model.state_dict())
        print()
    print(f'Best val Acc: {top_acc:4f}')
    model.load_state_dict(top_wts)
    return model
# 数据加载和预处理
def load_data(data_dir):
    """Build train/val DataLoaders from an ImageFolder directory layout.

    Expects `data_dir` to contain `train/` and `val/` subdirectories, each
    holding one folder per class.

    Returns:
        dict mapping 'train'/'val' to its DataLoader.
    """
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # ImageNet channel statistics, matching the pretrained backbone.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    # os.path.join instead of raw string concatenation so data_dir works both
    # with and without a trailing slash. (The original also computed
    # dataset_sizes and class_names but discarded them; removed as dead code.)
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in ['train', 'val']}
    dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4)
                   for x in ['train', 'val']}
    return dataloaders
# 主函数
def main():
    """Entry point: load data, build the model, run the three-phase
    fine-tuning schedule, and persist the resulting weights."""
    # Expects an ImageFolder layout: data/train/ and data/val/.
    data_dir = "data/"
    dataloaders = load_data(data_dir)
    # Binary-classification head on top of the pretrained backbone.
    model = VGG16_CBAM(num_classes=2, pretrained=True)
    criterion = nn.CrossEntropyLoss()
    # Prefer GPU when available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_ft = train_model_three_phase(model, dataloaders, criterion, device)
    torch.save(model_ft.state_dict(), 'vgg16_cbam_finetuned.pth')


if __name__ == "__main__":
    main()