# Day52
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Pin both RNGs so repeated runs reproduce the same results.
torch.manual_seed(42)
np.random.seed(42)

# MNIST preprocessing: to-tensor conversion followed by normalization
# with the dataset's standard per-channel mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Train/test splits of MNIST, downloaded on first use.
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform)

# Default loaders: shuffled mini-batches for training, large fixed
# batches for evaluation.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1000, shuffle=False)
# Basic CNN model definition
class SimpleCNN(nn.Module):
    """Two-block CNN classifier for 28x28 single-channel images (MNIST).

    Architecture: [Conv -> ReLU -> MaxPool(2)] x 2 -> Flatten -> Dropout
    -> Linear(10 logits).

    Args:
        num_filters1: output channels of the first conv block.
        num_filters2: output channels of the second conv block.
        kernel_size: side length of the square conv kernels.
        dropout_rate: dropout probability applied before the classifier.
    """

    def __init__(self, num_filters1=16, num_filters2=32, kernel_size=3, dropout_rate=0.25):
        super(SimpleCNN, self).__init__()
        # "Same"-style padding (kernel_size // 2) keeps the spatial size
        # compatible with the fixed 7x7 flatten below for any kernel_size.
        # The original hard-coded padding=1, which broke the Linear input
        # size for every kernel_size other than 3; for the default
        # kernel_size=3 this is unchanged (3 // 2 == 1).
        padding = kernel_size // 2
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, num_filters1, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(num_filters1, num_filters2, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            # 28x28 input -> 14x14 after the first pool -> 7x7 after the second.
            nn.Linear(num_filters2 * 7 * 7, 10)
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.fc(x)
        return x
# Training and evaluation helper
def train_and_evaluate(model, epochs=5, lr=0.001, batch_size=64, momentum=0.9):
    """Train *model* on MNIST and evaluate it after every epoch.

    Uses the module-level ``train_dataset`` and ``test_loader``; a fresh
    training DataLoader is built here so ``batch_size`` can vary per call.

    Args:
        model: the nn.Module to train (moved to GPU when available).
        epochs: number of full passes over the training set.
        lr: Adam learning rate.
        batch_size: training mini-batch size.
        momentum: unused — kept only for signature compatibility (Adam
            maintains its own moment estimates; this would apply to SGD).

    Returns:
        (train_losses, test_accuracies): per-epoch average training loss
        and per-epoch test accuracy in percent.
    """
    # Rebuild the train loader so the requested batch_size takes effect.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_losses = []
    test_accuracies = []
    for epoch in range(epochs):
        # ---- training phase ----
        model.train()
        train_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            if batch_idx % 100 == 0:
                print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                      f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
        avg_train_loss = train_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        # ---- evaluation phase ----
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # criterion returns the per-batch *mean*; scale by the batch
                # size so dividing by the dataset size below yields a true
                # per-sample average.  (The original summed batch means and
                # divided by the dataset size, under-reporting the loss.)
                test_loss += criterion(output, target).item() * data.size(0)
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
        test_accuracies.append(accuracy)
        print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
              f'({accuracy:.2f}%)\n')
    return train_losses, test_accuracies
# Compare the effect of different hyper-parameter settings
def compare_parameters():
    """Train a baseline CNN and a wider, more regularized variant, plot
    their training-loss / test-accuracy curves side by side, and report
    the final accuracy of each.

    Returns:
        (base_model, optimized_model): the two trained networks.
    """
    # Baseline configuration.
    base_params = dict(num_filters1=16, num_filters2=32,
                       kernel_size=3, dropout_rate=0.25)
    # Wider conv layers plus stronger dropout to curb overfitting.
    optimized_params = dict(num_filters1=32, num_filters2=64,
                            kernel_size=3, dropout_rate=0.5)

    print("训练基础模型...")
    base_model = SimpleCNN(**base_params)
    base_train_losses, base_test_accs = train_and_evaluate(
        base_model, epochs=10, lr=0.001)

    print("\n训练优化模型...")
    optimized_model = SimpleCNN(**optimized_params)
    # The larger model gets more epochs, a smaller lr and bigger batches.
    optimized_train_losses, optimized_test_accs = train_and_evaluate(
        optimized_model, epochs=15, lr=0.0005, batch_size=128)

    # Side-by-side comparison: losses on the left, accuracies on the right.
    plt.figure(figsize=(15, 6))

    plt.subplot(1, 2, 1)
    plt.plot(base_train_losses, label='基础模型训练损失')
    plt.plot(optimized_train_losses, label='优化模型训练损失')
    plt.xlabel('轮次')
    plt.ylabel('损失')
    plt.legend()
    plt.title('训练损失比较')

    plt.subplot(1, 2, 2)
    plt.plot(base_test_accs, label='基础模型测试准确率')
    plt.plot(optimized_test_accs, label='优化模型测试准确率')
    plt.xlabel('轮次')
    plt.ylabel('准确率 (%)')
    plt.legend()
    plt.title('测试准确率比较')

    plt.tight_layout()
    plt.savefig('cnn_parameter_comparison.png')
    plt.show()

    print(f"基础模型最终测试准确率: {base_test_accs[-1]:.2f}%")
    print(f"优化模型最终测试准确率: {optimized_test_accs[-1]:.2f}%")
    return base_model, optimized_model
# Entry point: run the baseline-vs-optimized parameter comparison.
if __name__ == "__main__":
    base_model, optimized_model = compare_parameters()