深度学习——基于卷积神经网络实现食物图像分类之(保存最优模型)

发布于:2025-09-04 ⋅ 阅读:(14) ⋅ 点赞:(0)

引言

本文将详细介绍如何使用PyTorch框架构建一个完整的食物图像分类系统,包含数据预处理、模型构建、训练优化以及模型保存等关键环节。与上一篇博客介绍的版本相比,本版本增加了模型保存与加载功能,并优化了测试评估流程。

一、项目概述

本项目的目标是构建一个能够识别20种不同食物的图像分类系统。主要技术特点包括:

  1. 简化但高效的数据预处理流程
  2. 三层CNN网络架构设计
  3. 训练过程中自动保存最佳模型
  4. 完整的训练-评估流程实现

二、环境配置

首先确保已安装必要的Python库:

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os

三、数据预处理

3.1 数据转换设置

我们为训练集和验证集定义了不同的转换策略:

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([256,256]),
        transforms.ToTensor(),
    ]),
    'valid': transforms.Compose([
        transforms.Resize([256,256]),
        transforms.ToTensor(),
    ]),
}

简化说明

  • 本版本简化了数据增强,仅保留基本的resize和tensor转换
  • 实际应用中可根据需求添加更多增强策略
3.2 数据集准备
def train_test_file(root, dir):
    file_txt = open(dir+'.txt','w')
    path = os.path.join(root,dir)
    for roots, directories, files in os.walk(path):
        if len(directories) != 0:
            dirs = directories
        else:
            now_dir = roots.split('\\')
            for file in files:
                path_1 = os.path.join(roots,file)
                file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')
    file_txt.close()

该函数会生成包含图像路径和标签的文本文件,格式为:

path/to/image1.jpg 0
path/to/image2.jpg 1
...

四、自定义数据集类

我们继承PyTorch的Dataset类实现自定义数据集:

class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.imgs = []
        self.labels = []
        self.transform = transform
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx])
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        label = torch.from_numpy(np.array(label, dtype=np.int64))
        return image, label

关键改进

  • 更清晰的数据加载逻辑
  • 完善的类型转换处理
  • 支持灵活的数据变换

五、CNN模型架构

我们设计了一个三层CNN网络:

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16,32,5,1,2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.out = nn.Linear(64*32*32, 20)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)
        return self.out(x)

架构特点

  1. 每层包含卷积、ReLU激活和最大池化
  2. 使用padding保持特征图尺寸
  3. 最后通过全连接层输出分类结果

六、训练与评估流程

def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_size_num % 1 == 0:
            print(f"loss: {loss.item():>7f} [batch:{batch_size_num}]")
        batch_size_num += 1
6.2 评估与模型保存
best_acc = 0

def Test(dataloader, model, loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    
    test_loss /= num_batches
    correct /= size
    
    # 保存最佳模型
    if correct > best_acc:
        best_acc = correct
        torch.save(model.state_dict(), "best_model.pth")
    
    print(f"\n测试结果: \n 准确率:{(100*correct):.2f}%, 平均损失:{test_loss:.4f}")

关键改进

  1. 增加全局变量best_acc跟踪最佳准确率
  2. 实现两种模型保存方式:(1)只保存模型参数(state_dict)(2)保存整个模型
  3. 更详细的测试结果输出

七、完整训练流程

# 初始化
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练循环
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n{'-'*20}")
    train(train_dataloader, model, loss_fn, optimizer)

# 最终评估
Test(test_dataloader, model, loss_fn)

八、模型保存与加载

8.1 保存模型
# 方法1:只保存参数
torch.save(model.state_dict(), "model_params.pth")

# 方法2:保存完整模型
torch.save(model, "full_model.pt")
8.2 加载模型
# 方法1对应加载方式
model = CNN().to(device)
model.load_state_dict(torch.load("model_params.pth"))

# 方法2对应加载方式
model = torch.load("full_model.pt").to(device)

九、优化建议

  1. 数据增强:添加更多变换提高模型泛化能力
  2. 学习率调度:使用torch.optim.lr_scheduler动态调整学习率