深度学习:自定义数据集处理、数据增强与最优模型管理

发布于:2025-09-05 ⋅ 阅读:(20) ⋅ 点赞:(0)

目录

一、整体流程概述

二、自定义数据集处理

1. 数据集结构设计

2. 数据集索引文件生成

三、自定义数据集加载类

四、数据增强策略

五、卷积神经网络模型构建

六、模型训练与最优模型管理

1. 训练与测试函数

2. 训练主流程

七、模型加载与推理


一、整体流程概述

深度学习项目通常遵循 "数据→模型→训练→部署" 的闭环流程,本文以食物识别任务为例,详解从自定义数据集处理到模型训练、优化及调用的完整流程。核心环节包括:

  1. 自定义数据集构建与预处理
  2. 数据增强策略设计
  3. 卷积神经网络 (CNN) 模型构建
  4. 模型训练与最优模型保存
  5. 训练后模型调用与推理

二、自定义数据集处理

1. 数据集结构设计

通常采用 "类别目录 + 图片文件" 的组织结构:

food_dataset/
├── train/
│   ├── 苹果/
│   │   ├── apple1.jpg
│   │   └── apple2.jpg
│   └── 香蕉/
│       └── ...
└── test/
    └── ...

2. 数据集索引文件生成

需要将图片路径与标签映射关系保存为文本文件(如train.txttest.txt),方便模型加载。

核心代码(食物识别 1 - 预处理.py)

import os

def train_test_file(root, dir):
    file_txt = open(dir + '.txt', 'w+', encoding='utf-8')
    file_canzhao = open('canzhao.txt', 'w+', encoding='utf-8')
    path = os.path.join(root, dir)
    dirs = []  # 存储类别目录名
    recorded_classes = set()  # 去重集合

    for roots, directories, files in os.walk(path):
        # 获取一级子目录作为类别
        if not dirs and roots == path:
            dirs = directories
        
        if files:  # 处理图片文件
            current_class = roots.split(os.sep)[-1]  # 提取类别名
            class_index = dirs.index(current_class)  # 分配类别索引
            # 写入图片路径和标签
            for file in files:
                img_path = os.path.join(roots, file)
                file_txt.write(f"{img_path} {class_index}\n")
            # 写入类别-索引映射(去重)
            if current_class not in recorded_classes:
                file_canzhao.write(f"{current_class} {class_index}\n")
                recorded_classes.add(current_class)

    file_txt.close()
    file_canzhao.close()

# 生成训练集和测试集索引文件
root = r'..\\food_dataset'
train_test_file(root, 'train')
train_test_file(root, 'test')
..\\food_dataset\train\八宝粥\img_八宝粥罐_22.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_29.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_65.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_68.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_84.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_98.jpeg 0
..\\food_dataset\train\哈密瓜\img_水果_103.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_13.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_136.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_142.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_163.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_174.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_191.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_209.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_238.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_30.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_34.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_42.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_44.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_57.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_81.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_92.jpeg 1
..\\food_dataset\train\圣女果\img_圣女果_104.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_105.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_116.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_12.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_13.jpeg 2
略

知识点解析

  • os.walk():递归遍历目录结构,获取所有文件路径
  • os.sep:适配不同操作系统的路径分隔符(Windows 用\,Linux 用/
  • 类别索引映射:通过dirs.index(current_class)建立类别与数字索引的映射,便于模型处理

三、自定义数据集加载类

使用 PyTorch 的Dataset类封装数据集,实现数据的按需加载。

核心代码

from torch.utils.data import Dataset
from PIL import Image
import torch
import numpy as np

class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.imgs = []
        self.labels = []
        self.transform = transform
        # 从索引文件读取数据
        with open(file_path, 'r', encoding='utf-8') as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.imgs)  # 返回样本总数

    def __getitem__(self, idx):
        # 读取图片并应用转换
        image = Image.open(self.imgs[idx]).convert('RGB')  # 确保RGB格式
        if self.transform:
            image = self.transform(image)
        # 标签转换为Tensor
        label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))
        return image, label

知识点解析

  • Dataset抽象类:必须实现__len__(返回样本数)和__getitem__(获取单个样本)方法
  • 延迟加载:仅在需要时才读取图片,节省内存
  • 数据转换接口:通过transform参数灵活接入数据增强 pipeline

四、数据增强策略

数据增强是提升模型泛化能力的关键手段,通过对训练数据进行随机变换,增加数据多样性。

核心代码(食物识别 1 - 数据增强.py)

from torchvision import transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((300, 300)),  # 缩放至300x300
        transforms.RandomRotation(45),  # 随机旋转(-45°~45°)
        transforms.CenterCrop(256),  # 中心裁剪至256x256
        transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻转
        transforms.RandomVerticalFlip(p=0.5),  # 50%概率垂直翻转
        transforms.ColorJitter(  # 颜色抖动
            brightness=0.2,  # 亮度
            contrast=0.1,    # 对比度
            saturation=0.1,  # 饱和度
            hue=0.1          # 色调
        ),
        transforms.RandomGrayscale(p=0.1),  # 10%概率转为灰度图
        transforms.ToTensor(),  # 转为Tensor(0-1范围)
        transforms.Normalize(  # 标准化
            [0.485, 0.456, 0.406],  # 均值
            [0.229, 0.224, 0.225]   # 标准差
        )
    ]),
    'valid': transforms.Compose([  # 验证集仅做必要转换
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

知识点解析

  • 训练集增强原则:保留语义信息的同时增加多样性(旋转、翻转、颜色变化等)
  • 验证集处理:仅做必要的尺寸调整和标准化,保证评估一致性
  • 标准化(Normalize):使用 ImageNet 数据集的均值和标准差,使输入分布更稳定
  • Compose:将多个变换组合成一个 pipeline,按顺序执行

五、卷积神经网络模型构建

设计适用于食物识别的 CNN 模型,通过卷积层提取视觉特征,全连接层完成分类。

核心代码

import torch
from torch import nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 卷积块1:特征提取+降维
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,  # 输入通道数(RGB)
                out_channels=16,  # 卷积核数量
                kernel_size=5,  # 卷积核大小5x5
                stride=1,  # 步长
                padding=2  # 填充,保持尺寸
            ),
            nn.ReLU(),  # 激活函数
            nn.MaxPool2d(kernel_size=2)  # 2x2池化,尺寸减半
        )
        # 卷积块2:加深特征提取
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )
        # 卷积块3:进一步提取+降维
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        # 卷积块4:高级特征提取
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )
        # 全连接层:分类输出
        self.line = nn.Linear(128 * 64 * 64, 20)  # 20类食物

    def forward(self, x):
        # 前向传播路径
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(x.size(0), -1)  # 展平特征图
        out = self.line(x)
        return out

知识点解析

  • 卷积层作用:通过滑动窗口提取局部特征(边缘、纹理、形状等)
  • 池化层作用:降低特征图尺寸,减少参数数量,增强平移不变性
  • 激活函数(ReLU):引入非线性,使模型能拟合复杂特征关系
  • 全连接层:将卷积提取的高维特征映射到类别空间

六、模型训练与最优模型管理

1. 训练与测试函数

# 训练函数
def train(dataloader, model, loss_fn, optimizer):
    model.train()  # 训练模式(启用dropout等)
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)  # 数据移至计算设备
        pred = model(X)  # 前向传播
        loss = loss_fn(pred, y)  # 计算损失
        
        # 反向传播与参数更新
        optimizer.zero_grad()  # 清空梯度
        loss.backward()  # 计算梯度
        optimizer.step()  # 更新参数
        
        # 打印训练进度
        if batch_size_num % 100 == 1:
            print(f"Loss: {loss.item():.4f} [batch:{batch_size_num}]")
        batch_size_num += 1

# 测试函数(含最优模型保存)
best_acc = 0
def test(dataloader, model, loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    correct = 0
    loss_sum = 0
    model.eval()  # 评估模式(关闭dropout等)
    with torch.no_grad():  # 关闭梯度计算
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss = loss_fn(pred, y)
            loss_sum += loss.item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    
    acc = correct / size
    latest_loss = loss_sum / num_batches
    print(f"Accuracy: {(acc * 100)}%, Loss: {latest_loss:.4f}")
    
    # 保存最优模型
    if acc > best_acc:
        best_acc = acc
        torch.save(model, 'best1.pt')  # 保存完整模型
        print(f"保存最优模型,准确率:{best_acc*100:.2f}%")

2. 训练主流程

# 设备选择
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f'Using device: {device}')

# 加载数据集
train_data = food_dataset('train.txt', data_transforms['train'])
test_data = food_dataset('test.txt', data_transforms['valid'])
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

# 初始化模型、损失函数和优化器
model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss()  # 多分类损失
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)  # Adam优化器

# 执行训练
epochs = 100
print('训练开始')
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f'训练结束, 最佳准确率:{best_acc*100:.2f}%')

知识点解析

  • 设备加速:自动选择 GPU(CUDA/MPS)或 CPU 进行计算
  • 训练模式与评估模式:model.train()model.eval()控制 dropout、BN 层等行为
  • 梯度管理:optimizer.zero_grad()避免梯度累积,with torch.no_grad()节省评估时内存
  • 最优模型保存:通过跟踪验证集准确率,只保存表现最好的模型,避免过拟合模型

七、模型加载与推理

训练完成后,加载最优模型进行实际预测。

核心代码(食物识别 1 - 调用最优模型.py)

# 加载模型
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.load('best1.pt', map_location=device)  # 加载完整模型
model.eval()  # 切换到评估模式

# 加载类别映射
def load_class_mapping(canzhao_path='canzhao.txt'):
    index_to_name = {}
    with open(canzhao_path, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            food_name, index = line.strip().split(' ')
            index_to_name[int(index)] = food_name
    return index_to_name

# 单张图片预测
def predict_image(image_path, model, index_to_name):
    try:
        # 图片预处理(与验证集一致)
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        
        image = Image.open(image_path).convert('RGB')
        image = transform(image).unsqueeze(0)  # 增加批次维度
        image = image.to(device)
        
        # 推理
        with torch.no_grad():
            pred = model(image)
            pred_index = pred.argmax(1).item()  # 获取最高概率类别
        
        return index_to_name.get(pred_index, "未知类别")
    except Exception as e:
        return f"预测失败:{str(e)}"

# 交互预测
index_to_name = load_class_mapping()
while True:
    img_path = input("请输入图片路径(q退出):")
    if img_path.lower() == 'q':
        break
    print("预测结果:", predict_image(img_path, model, index_to_name))

知识点解析

  • 模型加载:torch.load()加载保存的模型,map_location适配不同计算设备
  • 推理预处理:必须与训练时的验证集处理完全一致,否则会导致分布不匹配
  • 批次维度:模型输入需为(batch_size, channel, height, width),通过unsqueeze(0)添加批次维度
  • 类别映射:将模型输出的数字索引转换为实际类别名称

网站公告

今日签到

点亮在社区的每一天
去签到