import torch
import torch.nn as nn
import torch.optim as optim
from data_loader import create_data_loaders
from model import CNNModel
from train import train_model
from visualize import visualize_grad_cam
import os
import glob
import warnings
def _collect_sample_images(data_dir, class_names, extensions=('.jpg', '.jpeg', '.png')):
    """Pick one test image per class from ``<data_dir>/test/<class_name>/``.

    Files are matched by extension case-insensitively (so ``.JPG`` is found)
    and sorted, so the chosen sample is the same on every run; ``glob`` itself
    returns files in arbitrary order.

    Args:
        data_dir: Dataset root directory.
        class_names: Iterable of class names; each is used as a sub-directory
            name under ``<data_dir>/test``.
        extensions: Lower-case file suffixes accepted as images.

    Returns:
        list[str]: At most one image path per class, in ``class_names`` order.
    """
    samples = []
    for class_name in class_names:
        class_dir = os.path.join(data_dir, 'test', class_name)
        # glob.escape keeps characters like '[' in the directory name literal.
        candidates = glob.glob(os.path.join(glob.escape(class_dir), '*'))
        images = sorted(
            path for path in candidates
            if path.lower().endswith(extensions) and os.path.isfile(path)
        )
        if images:
            samples.append(images[0])
            print(f"为类别 '{class_name}' 选择样本: {os.path.basename(images[0])}")
        else:
            print(f"警告: 类别 '{class_name}' 未找到测试图像")
    return samples


def _get_eval_transform(train_loader, test_loader):
    """Return the preprocessing transform to apply to Grad-CAM input images.

    Prefers the test dataset's ``transform``. A training transform commonly
    contains random augmentation, which would make the visualisation
    non-deterministic (assumption about ``create_data_loaders`` — confirm).
    Falls back to the training dataset's transform, which is what this script
    used before.

    Raises:
        AttributeError: If the training dataset has no ``transform`` attribute
            and the test dataset provides none either (for example, when the
            loaders wrap a ``torch.utils.data.Subset``).
    """
    eval_transform = getattr(test_loader.dataset, 'transform', None)
    if eval_transform is not None:
        return eval_transform
    return train_loader.dataset.transform


def main():
    """Train a CNN classifier, reload the best checkpoint and render Grad-CAM maps.

    Steps:
        1. Build the data loaders with ``create_data_loaders``.
        2. Build ``CNNModel`` and train it with Adam and cross-entropy via
           ``train_model``, which receives ``config['best_model_path']`` as
           its ``save_path``.
        3. Reload that checkpoint if the file exists.
        4. Run ``visualize_grad_cam`` on one test image per class.

    Returns ``None`` early if data loading fails, if training is interrupted
    with Ctrl+C, or if no preprocessing transform can be found.
    """
    # Configuration
    config = {
        'data_dir': '/kaggle/input/dogs-vs-cats',  # Kaggle cats-vs-dogs dataset path
        'batch_size': 32,
        'num_epochs': 5,
        'learning_rate': 0.001,
        'img_size': 224,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'best_model_path': 'best_model.pth',
    }
    device = config['device']
    print(f"使用设备: {device}")

    # Load data
    print("加载数据...")
    try:
        train_loader, test_loader, class_names = create_data_loaders(
            data_dir=config['data_dir'],
            batch_size=config['batch_size'],
            img_size=config['img_size'],
        )
    except Exception as e:  # top-level boundary: report and stop instead of a traceback
        print(f"数据加载失败: {e}")
        return
    print(f"发现 {len(class_names)} 个类别: {class_names}")

    # Build the model
    print("初始化模型...")
    model = CNNModel(num_classes=len(class_names), pretrained=True).to(device)

    # Optimizer and loss function
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    # Train
    print(f"开始训练 ({config['num_epochs']} 个周期)...")
    try:
        trained_model = train_model(
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            criterion=criterion,
            optimizer=optimizer,
            num_epochs=config['num_epochs'],
            device=device,
            save_path=config['best_model_path'],
        )
    except KeyboardInterrupt:
        print("训练被用户中断")
        return

    # Reload the best checkpoint. map_location remaps tensors saved on another
    # device (e.g. a CUDA checkpoint on a CPU-only machine) onto `device`.
    if os.path.exists(config['best_model_path']):
        print("加载最佳模型...")
        state_dict = torch.load(config['best_model_path'], map_location=device)
        trained_model.load_state_dict(state_dict)
    else:
        warnings.warn("未找到最佳模型,使用最后训练的模型")

    # Inference mode: BatchNorm uses running statistics and dropout is off.
    # Gradients still flow, which Grad-CAM needs.
    trained_model.eval()

    # Grad-CAM visualisation
    print("生成Grad-CAM可视化...")
    sample_images = _collect_sample_images(config['data_dir'], class_names)

    try:
        eval_transform = _get_eval_transform(train_loader, test_loader)
    except AttributeError as e:
        print(f"无法获取图像预处理变换: {e}")
        return

    for img_path in sample_images:
        try:
            visualize_grad_cam(
                img_path=img_path,
                model=trained_model,
                class_names=class_names,
                transform=eval_transform,
                device=device,
            )
        except Exception as e:  # one bad image should not abort the remaining ones
            print(f"处理图像 {img_path} 时出错: {e}")
if __name__ == "__main__":
    # Run the full training + Grad-CAM pipeline only when executed as a script,
    # not when this module is imported.
    main()
# @浙大疏锦行