作业:
在Kaggle上找一个图像数据集,用CNN网络进行训练,并用Grad-CAM做可视化
进阶:将代码拆分成多个文件
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import warnings
warnings.filterwarnings("ignore")

# Seed both torch and NumPy so that results are reproducible.
torch.manual_seed(42)
np.random.seed(42)

# Preprocessing pipeline: resize every image to 64x64, convert it to a
# tensor, then normalize with 0.5 for both mean and std on three channels.
_CHANNEL_STATS = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_STATS, std=_CHANNEL_STATS),
])

# Load the custom dataset from a local folder.
dataset_path = r"F:\Program Files\MyPythonProjects\day43\music_instruments"
dataset = ImageFolder(root=dataset_path, transform=transform)

# Split into 80% training / remaining samples for testing.
# NOTE: the split draws from the global torch RNG seeded above, so it must
# stay after torch.manual_seed and before the model is constructed.
_num_samples = len(dataset)
train_size = int(0.8 * _num_samples)
test_size = _num_samples - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size]
)

# Human-readable class names.
# NOTE(review): presumably in the same order as the label indices produced
# by ImageFolder -- verify against dataset.classes.
classes = (
    'accordion', 'banjo', 'drum', 'flute', 'guitar',
    'harmonica', 'saxophone', 'sitar', 'tabla', 'violin',
)

# Build the model.
# NOTE(review): SimpleCNN is not defined in this part of the file; it has to
# be defined or imported before this line runs.
model = SimpleCNN()
print("模型已创建")

# Pick the first CUDA device when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = model.to(device)
# Train the model
def train_model(model, epochs=10, *, batch_size=32, lr=0.001, num_workers=2):
    """Train ``model`` in place on the module-level ``train_dataset``.

    Uses cross-entropy loss and the Adam optimizer. Batches are moved to the
    module-level ``device`` before the forward pass. The mean loss of every
    10 consecutive batches is printed; a trailing group of fewer than 10
    batches at the end of an epoch is not reported.

    Args:
        model: The ``nn.Module`` to optimize; expected to already live on
            ``device``.
        epochs: Number of full passes over the training set.
        batch_size: Samples per mini-batch.
        lr: Learning rate passed to Adam.
        num_workers: Number of DataLoader worker processes. On platforms that
            start workers by spawning (e.g. Windows), a value above 0
            requires the calling script to be protected by
            ``if __name__ == "__main__":``; pass 0 to load in-process.

    Returns:
        None. The model's parameters are updated in place.
    """
    loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Ensure layers such as dropout / batch-norm use their training behaviour
    # even if the caller previously switched the model to eval mode.
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Report the mean loss once every 10 batches, then reset.
            if i % 10 == 9:
                print(f'[{epoch + 1}, {i + 1}] 损失: {running_loss / 10:.3f}')
                running_loss = 0.0
    print("训练完成")
# Load a previously saved checkpoint if possible; otherwise train a new model
# and save it.
_CHECKPOINT_PATH = 'music_instruments_cnn.pth'
try:
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU can still be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(_CHECKPOINT_PATH, map_location=device))
    print("已加载预训练模型")
except Exception as exc:
    # Deliberate best-effort fallback: any load failure (missing file,
    # corrupt file, mismatched architecture) leads to training from scratch.
    # Catching Exception instead of using a bare except lets
    # KeyboardInterrupt / SystemExit propagate, and the reason is reported
    # instead of being silently dropped.
    print("无法加载预训练模型,使用未训练模型或训练新模型")
    print(f"加载失败原因: {exc}")
    train_model(model, epochs=10)
    torch.save(model.state_dict(), _CHECKPOINT_PATH)