Creating a Custom Dataset Class and a Hands-On Multi-Class Classification Problem

Published: 2025-07-16


🌟 6 Multi-Class Classification and Optimizing the Convolutional Model

Dataset: Multi-class Weather Dataset
Note: unlike the dataset used in Chapter 5, this is a multi-class task and all images live in a single folder.

# Import the core libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import glob  # file-path globbing
from torchvision import transforms  # data preprocessing
from PIL import Image  # image loading
from torch.utils import data  # Dataset / DataLoader utilities

🧩 6.1 Creating a Custom Dataset Class

⚠️ Dataset characteristics:

  1. Multi-class labels (not binary)
  2. All images stored in a single folder
  3. No predefined train/test split

Because the images are not organized into one subfolder per class, torchvision.datasets.ImageFolder cannot load this dataset.
✅ Instead, we subclass torch.utils.data.Dataset to build a custom dataset class.
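
For comparison, ImageFolder expects one subdirectory per class, while this dataset is flat (the file names below are illustrative, not actual listings):

# Layout ImageFolder expects:          # Actual (flat) layout here:
# root/cloudy/cloudy1.jpg              # dataset2/cloudy1.jpg
# root/rain/rain1.jpg                  # dataset2/rain1.jpg
# root/shine/shine1.jpg                # dataset2/shine1.jpg
# root/sunrise/sunrise1.jpg            # dataset2/sunrise1.jpg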

🔑 Key implementation steps:

# Collect all image paths
imgs = glob.glob(r'D:/my_all_learning/dataset2/dataset2/*.jpg')
print(imgs[:3])  # inspect the first 3 paths

# Define the class mappings
species = ['cloudy', 'rain', 'shine', 'sunrise']  # 4 classes
species_to_idx = dict((c, i) for i, c in enumerate(species))  # class name → index
idx_to_species = dict((i, c) for i, c in enumerate(species))  # index → class name
print(species_to_idx)
print(idx_to_species)
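
Since species is fixed, both prints are fully determined:

# {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
# {0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}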

# Build the label list from the file paths
labels = []
for img in imgs:
    for i, c in enumerate(species):
        if c in img:  # infer the class from the file name
            labels.append(i)
            break  # stop at the first match so each image gets exactly one label
print(labels[:3])

# Define the image preprocessing pipeline
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # map [0, 1] to [-1, 1]
])
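
As a quick sanity check (a minimal sketch; imgs[0] is simply the first path glob returned), the pipeline should yield a 3×96×96 tensor whose values Normalize has mapped from [0, 1] into [-1, 1]:

sample = transform(Image.open(imgs[0]).convert('RGB'))
print(sample.shape)                              # torch.Size([3, 96, 96])
print(sample.min().item(), sample.max().item())  # values fall within [-1.0, 1.0]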

🛠️ Implementing the custom Dataset class

class WT_Dataset(data.Dataset):
    def __init__(self, imgs_path, labels):
        self.imgs_path = imgs_path
        self.labels = labels

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        label = self.labels[index]
        pil_img = Image.open(img_path)
        pil_img = pil_img.convert('RGB')  # force 3-channel RGB (guards against grayscale/RGBA files)
        pil_img = transform(pil_img)      # apply the preprocessing pipeline defined above
        return pil_img, label
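
Any object implementing __len__ and __getitem__ this way can be consumed by DataLoader. As a quick check (a throwaway sketch reusing the imgs and labels lists from above), indexing should return a preprocessed tensor plus an integer label:

check_ds = WT_Dataset(imgs, labels)  # throwaway instance just to test indexing
x, y = check_ds[0]
print(x.shape, y)  # torch.Size([3, 96, 96]) and an int class index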

📊 Splitting the dataset and visualizing samples

# Instantiate the dataset
dataset = WT_Dataset(imgs, labels)
print(f"Total samples: {len(dataset)}")

# Split into a training set (80%) and a test set (20%)
train_count = int(0.8 * len(dataset))
test_count = len(dataset) - train_count
train_dataset, test_dataset = data.random_split(dataset, [train_count, test_count])
print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")

# Create the DataLoaders
BATCH_SIZE = 16
train_dl = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dl = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Visualize the first batch of images
imgs_batch, labels_batch = next(iter(train_dl))
plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs_batch[:6], labels_batch[:6])):
    img = (img.permute(1, 2, 0).numpy() + 1) / 2  # channels-last, then un-normalize [-1, 1] back to [0, 1]
    plt.subplot(2, 3, i + 1)
    plt.title(idx_to_species.get(label.item()))  # show the class name
    plt.imshow(img)
plt.show()

🧠 6.2 A Baseline Convolutional Model

📐 Network architecture

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 16, 3)    # 3 input channels → 16, 3x3 kernel
        self.conv2 = nn.Conv2d(16, 32, 3)   # 16 → 32 channels
        self.conv3 = nn.Conv2d(32, 64, 3)   # 32 → 64 channels
        # Fully connected layers
        self.fc1 = nn.Linear(64*10*10, 1024)  # flattened feature map as input
        self.fc2 = nn.Linear(1024, 4)         # 4-way classification output

    def forward(self, x):
        # [batch, 3, 96, 96] → conv1 → [batch, 16, 94, 94]
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)  # → [batch, 16, 47, 47]

        # → conv2 → [batch, 32, 45, 45]
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)  # → [batch, 32, 22, 22] (floor of 45/2)

        # → conv3 → [batch, 64, 20, 20]
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)  # → [batch, 64, 10, 10]

        # Flatten → fully connected
        x = x.view(-1, 64*10*10)
        x = F.relu(self.fc1(x))  # → [batch, 1024]
        x = self.fc2(x)          # → [batch, 4]
        return x
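
A dummy forward pass verifies the shape arithmetic in the comments above (a quick sketch on CPU):

net = Net()
dummy = torch.randn(1, 3, 96, 96)  # one fake RGB image at 96x96
print(net(dummy).shape)            # expected: torch.Size([1, 4])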

⚙️ Training configuration

# Select the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Initialize the model, loss, and optimizer
model = Net().to(device)
loss_fn = nn.CrossEntropyLoss()  # cross-entropy loss (expects raw logits)
optimizer = optim.Adam(model.parameters(), lr=0.0005)  # Adam optimizer
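
As a rough size check (a quick sketch), counting the trainable parameters shows the model is dominated by fc1's 6400→1024 weight matrix:

n_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {n_params:,}")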

🔁 Training and test functions

def train(dataloader, model, loss_fn, optimizer):
    model.train()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_loss, correct = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # Forward pass
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate metrics
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_loss /= num_batches
    correct /= size
    return train_loss, correct

def test(dataloader, model, loss_fn):  # take loss_fn explicitly, mirroring train()
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    return test_loss, correct

📈 Training and evaluating the model

epochs = 20
train_loss, train_acc = [], []
test_loss, test_acc = [], []

for epoch in range(epochs):
    # One training pass
    epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)
    # One evaluation pass
    epoch_test_loss, epoch_test_acc = test(test_dl, model, loss_fn)

    # Record the metrics
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)

    # Log progress (1-based epoch numbering, matching the plots below)
    template = ("epoch:{:2d}, train_loss:{:.5f}, train_acc:{:.1f}%, "
                "test_loss:{:.5f}, test_acc:{:.1f}%")
    print(template.format(epoch + 1, epoch_loss, epoch_acc*100,
                          epoch_test_loss, epoch_test_acc*100))

print("Training complete!")

📉 Visualizing the results

# Loss curves
plt.plot(range(1, epochs+1), train_loss, label='train_loss')
plt.plot(range(1, epochs+1), test_loss, label='test_loss')
plt.legend()
plt.title("Training vs. test loss")
plt.show()

# Accuracy curves
plt.plot(range(1, epochs+1), train_acc, label='train_acc')
plt.plot(range(1, epochs+1), test_acc, label='test_acc')
plt.legend()
plt.title("Training vs. test accuracy")
plt.show()

🚨 Key issue: overfitting

With this baseline, the training metrics keep improving while the test metrics stall and drift away from them, the classic signature of overfitting. The next chapter covers convolutional-network optimization techniques (Dropout, batch normalization, learning-rate decay) that improve generalization; a brief preview is sketched below.
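
As a preview, the sketch below (an illustrative example, not the next chapter's exact code) shows where each technique plugs in:

class NetImproved(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.bn1 = nn.BatchNorm2d(16)        # batch norm stabilizes conv activations
        self.fc1 = nn.Linear(16 * 47 * 47, 1024)
        self.drop = nn.Dropout(0.5)          # randomly zeroes 50% of fc activations
        self.fc2 = nn.Linear(1024, 4)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2)  # 96 → 94 → 47
        x = x.view(x.size(0), -1)
        x = self.drop(F.relu(self.fc1(x)))   # dropout is active only in model.train() mode
        return self.fc2(x)

# Learning-rate decay: halve the learning rate every 5 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# call scheduler.step() once per epoch inside the training loop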


Keywords: multi-class classification, convolutional neural networks, custom Dataset, overfitting, PyTorch

