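# Day 43: 之前在自己收集的数据集上用简单 CNN 训练；本版改用预训练模型，并加入注意力机制（SE / CBAM）等。
# (Day 43: previously a simple CNN was trained on a self-collected dataset; this version switches to pretrained models and adds attention such as SE / CBAM.)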
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, models, transforms
# Configure matplotlib so the Chinese plot labels render: SimHei supplies the
# CJK glyphs, and unicode_minus=False draws minus signs with an ASCII hyphen
# (the Unicode minus glyph may be missing from the CJK font).
plt.rcParams.update({
    "font.family": ["SimHei"],
    "axes.unicode_minus": False,
})

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"使用设备: {device}")
# 1. Data preprocessing

# Per-channel (R, G, B) mean and standard deviation of this fish dataset,
# shared by the train and test Normalize steps so the two cannot drift apart.
# NOTE(review): these values were presumably produced by an earlier version of
# calculate_mean_std that averaged per-image stds (which underestimates the
# dataset-wide std); re-run calculate_mean_std below to confirm them.
# The ImageNet statistics ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) are the
# alternative the author kept as an option, e.g. for pretrained backbones.
DATASET_MEAN = (0.4790, 0.4813, 0.4370)
DATASET_STD = (0.2123, 0.2066, 0.2085)


def calculate_mean_std(dataloader):
    """Compute the per-channel pixel mean and standard deviation of a dataset.

    Meant to be run once, on a loader WITHOUT random augmentation (e.g. only
    Resize + ToTensor), to obtain the constants used by ``Normalize``.

    Args:
        dataloader: iterable of ``(images, labels)`` batches where ``images``
            is a ``(batch, channels, height, width)`` tensor.

    Returns:
        ``(mean, std)``: two 1-D float tensors with one entry per channel.
        ``std`` is the population standard deviation over every pixel of
        every image, not the average of per-image standard deviations.

    Raises:
        ValueError: if the loader yields no pixels.
    """
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    for images, _ in dataloader:
        # (B, C, H, W) -> (C, B*H*W); float64 avoids precision loss over many pixels.
        flat = images.transpose(0, 1).reshape(images.size(1), -1).double()
        if channel_sum is None:
            channel_sum = torch.zeros(flat.size(0), dtype=torch.float64)
            channel_sq_sum = torch.zeros(flat.size(0), dtype=torch.float64)
        channel_sum += flat.sum(dim=1)
        channel_sq_sum += (flat * flat).sum(dim=1)
        pixel_count += flat.size(1)
    if pixel_count == 0:
        raise ValueError("calculate_mean_std: dataloader yielded no pixels")
    mean = channel_sum / pixel_count
    # Var[X] = E[X^2] - E[X]^2; the clamp guards against tiny negative rounding errors.
    std = (channel_sq_sum / pixel_count - mean * mean).clamp_min(0.0).sqrt()
    return mean.float(), std.float()


# Example (run once, on an un-augmented view of the data):
# temp_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
# temp_dataset = datasets.ImageFolder(root=your_data_root, transform=temp_transform)
# temp_loader = DataLoader(temp_dataset, batch_size=32, shuffle=False)
# mean, std = calculate_mean_std(temp_loader)

# Training preprocessing: random augmentation, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(DATASET_MEAN, DATASET_STD),
])

# Evaluation preprocessing: deterministic resize only, no random augmentation.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(DATASET_MEAN, DATASET_STD),
])
# 2. Load the custom dataset.
# Forward slashes work on both Windows and POSIX; the backslash form only
# resolved on Windows.
DATA_ROOT = "BengaliFishImages/fish_images"
# Fixed seed so the train/test split is identical on every run.
SPLIT_SEED = 42

# Two views of the same folder. ImageFolder builds its sample list
# deterministically (sorted class folders and file names), so index i refers to
# the same image in both. The training view applies random augmentation; the
# evaluation view uses the deterministic test_transform. Previously the test
# split was a subset of the augmented dataset, so evaluation ran on randomly
# cropped, jittered and rotated images.
full_dataset = datasets.ImageFolder(root=DATA_ROOT, transform=train_transform)
eval_dataset = datasets.ImageFolder(root=DATA_ROOT, transform=test_transform)

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(len(full_dataset), generator=split_generator).tolist()
train_dataset = Subset(full_dataset, shuffled_indices[:train_size])
test_dataset = Subset(eval_dataset, shuffled_indices[train_size:])

# 3. Data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 4. Attention modules
# SE (Squeeze-and-Excitation) attention block
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools each channel, passes the channel vector through a
    bottleneck MLP (channel -> hidden -> channel, no biases) that ends in a
    sigmoid, and rescales every channel of the input by its gate in (0, 1).

    Args:
        channel: number of input channels C.
        reduction: bottleneck reduction ratio. The hidden width is
            ``max(1, channel // reduction)``, so ``channel < reduction`` no
            longer creates a zero-width layer (which yielded constant 0.5 gates).
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Reweight ``x`` (B, C, H, W) per channel; the output has the same shape."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)  # (B, C) channel descriptors
        y = self.fc(y).view(b, c, 1, 1)  # (B, C, 1, 1) gates
        return x * y.expand_as(x)
# CBAM (Convolutional Block Attention Module) components
class ChannelAttention(nn.Module):
    """CBAM channel-attention gate.

    Squeezes the spatial dimensions with both average and max pooling, sends
    each (B, C, 1, 1) descriptor through the same 1x1-conv bottleneck MLP, sums
    the two results and applies a sigmoid. Only the gate (B, C, 1, 1) is
    returned; the caller multiplies it into the features.

    Args:
        in_planes: number of input channels C.
        ratio: bottleneck reduction ratio. The hidden width is
            ``max(1, in_planes // ratio)``, so ``in_planes < ratio`` does not
            produce an empty layer.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        hidden = max(1, in_planes // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _shared_mlp(self, z):
        # The same bottleneck weights are applied to both pooled descriptors.
        return self.fc2(self.relu1(self.fc1(z)))

    def forward(self, x):
        avg_out = self._shared_mlp(self.avg_pool(x))
        max_out = self._shared_mlp(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)
class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate.

    Collapses the channel axis into two maps (per-pixel mean and per-pixel max),
    stacks them, convolves with one ``kernel_size`` filter (padding
    ``kernel_size // 2``, which preserves H and W for odd kernel sizes, no bias)
    and squashes the result with a sigmoid. Only the single-channel gate is
    returned.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        half_kernel = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=half_kernel, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(stacked))
class CBAMBlock(nn.Module):
    """CBAM: channel attention followed by spatial attention.

    The input is first rescaled by the channel gate. The spatial gate is then
    computed from that refined tensor and multiplied into it.
    """

    def __init__(self, channel, ratio=16, kernel_size=7):
        super(CBAMBlock, self).__init__()
        self.channel_attention = ChannelAttention(channel, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        refined = x * self.channel_attention(x)
        return refined * self.spatial_attention(refined)
# 5. Custom CNN (optionally with SE or CBAM attention after each conv block)
class ImprovedCNN(nn.Module):
    """Four conv blocks (conv-BN-ReLU-maxpool[-attention]) plus a two-layer classifier.

    Args:
        num_classes: number of output logits.
        attention_type: None, 'se' or 'cbam'. Selects the attention module
            applied after each conv block.

    Raises:
        ValueError: if ``attention_type`` is not one of the supported values.

    Each block halves H and W. For 224x224 input (what the transforms produce)
    the map goes 224 -> 112 -> 56 -> 28 -> 14; for 128x128 input it goes
    128 -> 64 -> 32 -> 16 -> 8. An adaptive average pool then fixes the map to
    8x8, so ``fc1`` always receives 256 * 8 * 8 features. The previous
    ``view(-1, 256 * 8 * 8)`` failed, or silently mixed batch rows, for
    224x224 input.
    """

    _ATTENTION_TYPES = (None, 'se', 'cbam')

    def __init__(self, num_classes=20, attention_type=None):
        super(ImprovedCNN, self).__init__()
        if attention_type not in self._ATTENTION_TYPES:
            raise ValueError(f"不支持的注意力类型: {attention_type}")
        self.attention_type = attention_type

        # Block 1: 3 -> 32 channels, H/2 x W/2
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.att1 = self._build_attention(attention_type, 32)

        # Block 2: 32 -> 64 channels, H/4 x W/4
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.att2 = self._build_attention(attention_type, 64)

        # Block 3: 64 -> 128 channels, H/8 x W/8
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(2)
        self.att3 = self._build_attention(attention_type, 128)

        # Block 4: 128 -> 256 channels, H/16 x W/16
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(2)
        self.att4 = self._build_attention(attention_type, 256)

        # Classifier. The adaptive pool maps any remaining spatial size to 8x8
        # (an identity for 128x128 input), so fc1's shape is input-size independent.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 8))
        self.fc1 = nn.Linear(256 * 8 * 8, 512)
        self.relu_fc = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512, num_classes)

    @staticmethod
    def _build_attention(attention_type, channels):
        """Return the attention module for one block (an identity when disabled)."""
        if attention_type == 'se':
            return SEBlock(channels)
        if attention_type == 'cbam':
            return CBAMBlock(channels)
        return nn.Identity()

    def forward(self, x):
        x = self.att1(self.pool1(self.relu1(self.bn1(self.conv1(x)))))
        x = self.att2(self.pool2(self.relu2(self.bn2(self.conv2(x)))))
        x = self.att3(self.pool3(self.relu3(self.bn3(self.conv3(x)))))
        x = self.att4(self.pool4(self.relu4(self.bn4(self.conv4(x)))))
        x = self.adaptive_pool(x)
        # flatten(x, 1) always keeps the batch dimension intact.
        x = torch.flatten(x, 1)
        x = self.dropout(self.relu_fc(self.fc1(x)))
        return self.fc2(x)
# 6. Classifiers built on torchvision pretrained backbones
def create_pretrained_model(model_name, num_classes=20, freeze_feature=True, attention_type=None):
    """Build a classifier on top of a torchvision pretrained backbone.

    Args:
        model_name: 'resnet50', 'vgg16' or 'mobilenet_v2'.
        num_classes: number of output logits of the new head.
        freeze_feature: if True, set ``requires_grad=False`` on the pretrained
            parameters (all of ResNet-50; ``features`` for VGG16/MobileNetV2).
            Modules added afterwards (attention, new head) stay trainable.
        attention_type: None, 'se' or 'cbam'.

    Returns:
        The assembled ``nn.Module``. It is not moved to any device.

    Raises:
        ValueError: for an unsupported ``model_name`` or ``attention_type``.
            The attention type is validated before any weights are loaded.
            Previously, unknown values were silently ignored (resnet50) or
            treated as CBAM (vgg16/mobilenet_v2).
    """
    if attention_type not in (None, 'se', 'cbam'):
        raise ValueError(f"不支持的注意力类型: {attention_type}")

    def make_attention(channels):
        # Only called when attention_type is 'se' or 'cbam'.
        return SEBlock(channels) if attention_type == 'se' else CBAMBlock(channels)

    def make_head(in_features):
        # New, trainable classification head shared by all backbones.
        return nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    # NOTE: newer torchvision versions deprecate ``pretrained=True`` in favour of
    # ``weights=...``; it is kept here for compatibility with older releases.
    if model_name == 'resnet50':
        model = models.resnet50(pretrained=True)
        if freeze_feature:
            for param in model.parameters():
                param.requires_grad = False
        if attention_type is not None:
            # Wrap the first conv of layer4's first bottleneck (512 output
            # channels) so attention is applied to its output.
            first_block = model.layer4[0]
            first_block.conv1 = nn.Sequential(first_block.conv1, make_attention(512))
        model.fc = make_head(model.fc.in_features)

    elif model_name == 'vgg16':
        model = models.vgg16(pretrained=True)
        if freeze_feature:
            for param in model.features.parameters():
                param.requires_grad = False
        if attention_type is not None:
            # Append attention after the last feature layer (512 channels).
            model.features = nn.Sequential(*model.features.children(), make_attention(512))
        model.classifier[6] = make_head(model.classifier[6].in_features)

    elif model_name == 'mobilenet_v2':
        model = models.mobilenet_v2(pretrained=True)
        if freeze_feature:
            for param in model.features.parameters():
                param.requires_grad = False
        if attention_type is not None:
            # Append attention after the last feature layer (1280 channels).
            model.features = nn.Sequential(*model.features.children(), make_attention(1280))
        # Flat Sequential: Dropout(0.2) followed by the shared head layers.
        model.classifier = nn.Sequential(nn.Dropout(0.2), *make_head(model.classifier[1].in_features))

    else:
        raise ValueError(f"不支持的模型名称: {model_name}")
    return model
# 7. Training and evaluation
def _run_train_epoch(model, loader, criterion, optimizer, device, epoch, epochs):
    """Train for one epoch; return (per-batch losses, mean batch loss, accuracy %)."""
    # Must be re-enabled every epoch: evaluation switches the model to eval().
    model.train()
    batch_losses = []
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        iter_loss = loss.item()
        batch_losses.append(iter_loss)
        running_loss += iter_loss
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        if (batch_idx + 1) % 100 == 0:
            print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(loader)} '
                  f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')
    mean_loss = running_loss / len(batch_losses) if batch_losses else 0.0
    accuracy = 100. * correct / total if total else 0.0
    return batch_losses, mean_loss, accuracy


def _evaluate(model, loader, criterion, device):
    """Evaluate without gradients in eval mode; return (mean batch loss, accuracy %)."""
    model.eval()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()
            num_batches += 1
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    # Empty loaders report 0.0 instead of raising ZeroDivisionError.
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0
    return mean_loss, accuracy


def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs,
          *, show_plots=True):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``test_loader`` after each.

    Args:
        model, criterion, optimizer: standard PyTorch training objects.
        train_loader, test_loader: iterables of ``(data, target)`` batches.
        scheduler: stepped once per epoch with the epoch's test loss
            (e.g. ``ReduceLROnPlateau``).
        device: device the batches are moved to.
        epochs: number of epochs to run.
        show_plots: if True (default), show the per-iteration loss curve and
            the per-epoch accuracy/loss curves after training.

    Returns:
        Test accuracy (%) of the last epoch, or 0.0 if no epoch was run.
    """
    all_iter_losses = []
    iter_indices = []
    train_acc_history = []
    test_acc_history = []
    train_loss_history = []
    test_loss_history = []
    steps_per_epoch = len(train_loader)

    for epoch in range(epochs):
        batch_losses, epoch_train_loss, epoch_train_acc = _run_train_epoch(
            model, train_loader, criterion, optimizer, device, epoch, epochs)
        all_iter_losses.extend(batch_losses)
        # Global 1-based iteration index across epochs.
        iter_indices.extend(epoch * steps_per_epoch + i + 1 for i in range(len(batch_losses)))
        train_acc_history.append(epoch_train_acc)
        train_loss_history.append(epoch_train_loss)

        epoch_test_loss, epoch_test_acc = _evaluate(model, test_loader, criterion, device)
        test_acc_history.append(epoch_test_acc)
        test_loss_history.append(epoch_test_loss)

        scheduler.step(epoch_test_loss)
        print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')

    if show_plots:
        plot_iter_losses(all_iter_losses, iter_indices)
        plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
    return test_acc_history[-1] if test_acc_history else 0.0
# 8. Plotting helpers
def plot_iter_losses(losses, indices):
    """Plot each iteration's (batch's) training loss against its global index, then show it."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')
    ax.set_xlabel('Iteration(Batch序号)')
    ax.set_ylabel('损失值')
    ax.set_title('每个 Iteration 的训练损失')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):
    """Show per-epoch train/test accuracy (left panel) and loss (right panel) curves."""
    epoch_axis = range(1, len(train_acc) + 1)
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (acc_ax, train_acc, test_acc, '训练准确率', '测试准确率', '准确率 (%)', '训练和测试准确率'),
        (loss_ax, train_loss, test_loss, '训练损失', '测试损失', '损失值', '训练和测试损失'),
    )
    for ax, train_vals, test_vals, train_label, test_label, y_label, title in panels:
        ax.plot(epoch_axis, train_vals, 'b-', label=train_label)
        ax.plot(epoch_axis, test_vals, 'r-', label=test_label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)
    fig.tight_layout()
    plt.show()
# 9. Training configuration and entry point
def main():
    """Configure, train, evaluate and save a fish classifier.

    Uses the module-level ``full_dataset``, ``train_loader``, ``test_loader``
    and ``device``. Saves the final ``state_dict`` to
    ``<model_type>_<attention_type or 'no_att'>_fish_model.pth``.
    """
    # Model: 'custom' (ImprovedCNN), 'resnet50', 'vgg16' or 'mobilenet_v2'.
    model_type = 'resnet50'
    # Attention: None, 'se' or 'cbam'.
    attention_type = 'cbam'
    epochs = 30  # pretrained models usually need fewer epochs
    # Derived from the dataset so the output layer always matches the number of
    # ImageFolder class folders (previously hard-coded to 20).
    num_classes = len(full_dataset.classes)

    if model_type == 'custom':
        print(f"使用自定义CNN模型,注意力机制: {attention_type}")
        model = ImprovedCNN(num_classes=num_classes, attention_type=attention_type).to(device)
    else:
        # The pretrained backbone is deliberately trained with attention disabled
        # and frozen features (only the new head learns). Reset attention_type so
        # the log line and the checkpoint filename describe the model actually
        # built; previously they claimed the configured attention ('cbam').
        # To fine-tune the whole network instead, pass freeze_feature=False (an
        # earlier setup used Adam(lr=1e-3) with ReduceLROnPlateau(patience=3, factor=0.5)).
        attention_type = None
        print(f"使用预训练{model_type}模型,注意力机制: {attention_type}")
        model = create_pretrained_model(
            model_name=model_type,
            num_classes=num_classes,
            freeze_feature=True,
            attention_type=attention_type,
        ).to(device)

    criterion = nn.CrossEntropyLoss()
    # Hand the optimizer only the parameters that receive gradients (frozen ones
    # would be skipped anyway). The small learning rate suits a pretrained backbone.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=5, factor=0.5, min_lr=1e-6
    )

    print("开始训练...")
    final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
    print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

    model_filename = f"{model_type}_{attention_type if attention_type else 'no_att'}_fish_model.pth"
    torch.save(model.state_dict(), model_filename)
    print(f"模型已保存为: {model_filename}")


if __name__ == "__main__":
    main()