目录
一、Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
一、Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
1、MNIST数据处理
在遇到大规模数据集时,显存常常无法一次性存储所有数据,所以需要使用分批训练的方法。为此,PyTorch提供了DataLoader类,该类可以自动将数据集切分为多个批次batch,并支持多线程加载数据。此外,还存在Dataset类,该类可以定义数据集的读取方式和预处理方式。
- DataLoader类:决定数据如何加载
- Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理。
为了引入这些概念,现在接触一个新的而且非常经典的数据集:MNIST手写数字数据集。该数据集包含60000张训练图片和10000张测试图片,每张图片大小为28*28像素,共包含10个类别。因为每个数据的维度比较小,所以既可以视为结构化数据,用机器学习、MLP训练,也可以视为图像数据,用卷积神经网络训练。
torchvision
├── datasets # 视觉数据集(如 MNIST、CIFAR)
├── transforms # 视觉数据预处理(如裁剪、翻转、归一化)
├── models # 预训练模型(如 ResNet、YOLO)
├── utils # 视觉工具函数(如目标检测后处理)
└── io # 图像/视频 IO 操作
# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作
# 先归一化,再标准化
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量并归一化到[0,1]
transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])
# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transform
)
test_dataset = datasets.MNIST(
root='./data',
train=False,
transform=transform
)
这里稍微有点反逻辑,正常思路应该是先有数据集,后续再处理。但是在pytorch的思路是,数据在加载阶段就处理结束。
2、Dataset类
现在我们想要取出来一个图片,看看长啥样,因为datasets.MNIST本质上集成了torch.utils.data.Dataset,所以自然需要有对应的方法。
import matplotlib.pyplot as plt
# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
这里很难理解,为什么train_dataset[sample_idx]可以获取到图片和标签,是因为 datasets.MNIST这个类继承了torch.utils.data.Dataset类,这个类中有一个方法__getitem__,这个方法会返回一个tuple,tuple中第一个元素是图片,第二个元素是标签。
详细介绍下torch.utils.data.Dataset类
PyTorch 的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:
- len():返回数据集的样本总数。
- getitem(idx):根据索引idx返回对应样本的数据和标签。
PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。
在 Python 中,getitem__和__len 是类的特殊方法(也叫魔术方法 ),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,来赋予类特定的行为。以下是关于这两个方法具体的使用方式:
__getitem__方法
__getitem__方法用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。
# 示例代码
class MyList:
def __init__(self):
self.data = [10, 20, 30, 40, 50]
def __getitem__(self, idx):
return self.data[idx]
# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2]) # 输出:30
通过定义__getitem__方法,让MyList类的实例能够像 Python 内置的列表一样使用索引获取元素。
__len__方法
__len__方法用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。
class MyList:
def __init__(self):
self.data = [10, 20, 30, 40, 50]
def __len__(self):
return len(self.data)
# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj)) # 输出:5
这里定义的__len__方法,使得MyList类的实例可以像普通列表一样被len()函数调用获取长度。
# minist数据集的简化版本
class MNIST(Dataset):
def __init__(self, root, train=True, transform=None):
# 初始化:加载图片路径和标签
self.data, self.targets = fetch_mnist_data(root, train) # 这里假设 fetch_mnist_data 是一个函数,用于加载 MNIST 数据集的图片路径和标签
self.transform = transform # 预处理操作
def __len__(self):
return len(self.data) # 返回样本总数
def __getitem__(self, idx): # 获取指定索引的样本
# 获取指定索引的图像和标签
img, target = self.data[idx], self.targets[idx]
# 应用图像预处理(如ToTensor、Normalize)
if self.transform is not None: # 如果有预处理操作
img = self.transform(img) # 转换图像格式
# 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化
return img, target # 返回处理后的图像和标签
- Dataset = 厨师(准备单个菜品)
- DataLoader = 服务员(将菜品按订单组合并上桌)
预处理(如切菜、调味)属于厨师的工作,而非服务员。所以在dataset就需要添加预处理步骤。
# 可视化原始图像(需要反归一化)
def imshow(img):
img = img * 0.3081 + 0.1307 # 反标准化
npimg = img.numpy()
plt.imshow(npimg[0], cmap='gray') # 显示灰度图像
plt.show()
print(f"Label: {label}")
imshow(image)
二、Dataloader类
# 3. 创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
shuffle=True # 随机打乱数据
)
test_loader = DataLoader(
test_dataset,
batch_size=1000 # 每个批次1000张图片
# shuffle=False # 测试时不需要打乱数据
)
总结:
维度 | Dataset |
DataLoader |
---|---|---|
核心职责 | 定义“数据是什么”和“如何获取单个样本” | 定义“如何批量加载数据”和“加载策略” |
核心方法 | __getitem__ (获取单个样本)、__len__ (样本总数) |
无自定义方法,通过参数控制加载逻辑 |
预处理位置 | 在__getitem__ 中通过transform 执行预处理 |
无预处理逻辑,依赖Dataset 返回的预处理后数据 |
并行处理 | 无(仅单样本处理) | 支持多进程加载(num_workers>0 ) |
典型参数 | root (数据路径)、transform (预处理) |
batch_size 、shuffle 、num_workers |
核心结论
Dataset
类:定义数据的内容和格式(即“如何获取单个样本”),包括:- 数据存储路径/来源(如文件路径、数据库查询)。
- 原始数据的读取方式(如图像解码为PIL对象、文本读取为字符串)。
- 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过
transform
参数实现)。 - 返回值格式(如
(image_tensor, label)
)。
DataLoader
类:定义数据的加载方式和批量处理逻辑(即“如何高效批量获取数据”),包括:- 批量大小(
batch_size
)。 - 是否打乱数据顺序(
shuffle
)。
- 批量大小(
三、minist手写数据集的了解
1、MNIST 数据集基本概况
MNIST(Modified National Institute of Standards and Technology database)是机器学习领域最经典的入门级数据集之一,由纽约大学的 Yann LeCun 等人整理,被广泛用于图像分类算法的开发与测试。
数据来源:
- 训练集:来自 250 个不同人手写的数字,其中 50% 是高中学生,50% 是人口普查局员工。
- 测试集:与训练集独立的另一组手写数字样本。
数据规模:
- 训练集:60,000 张灰度图像
- 测试集:10,000 张灰度图像
- 每张图像尺寸:28×28 像素
数据格式:
- 图像:单通道灰度图,像素值范围为 0-255(归一化后通常为 0-1)
- 标签:0-9 的数字分类,共 10 个类别
2、MNIST 数据集的特点与应用场景
特点:
- 数据标准化:所有图像已统一尺寸并中心化处理
- 噪声较少:经过筛选和预处理,适合入门级算法验证
- 类别平衡:每个数字的样本数量相对均匀
典型应用场景:
- 机器学习入门教程的标准案例
- 新算法的基准测试(如 CNN、神经网络等)
- 迁移学习的源领域数据
- 模型轻量化与压缩技术的测试平台
3、MNIST 数据的存储与读取
MNIST 数据集通常以 IDX 格式存储,这是一种用于存储多维度数组的简单文件格式:
文件结构:
train-images-idx3-ubyte
:训练集图像(60,000×28×28)train-labels-idx1-ubyte
:训练集标签(60,000 个)t10k-images-idx3-ubyte
:测试集图像(10,000×28×28)t10k-labels-idx1-ubyte
:测试集标签(10,000 个)
4、MNIST 数据集的性能基准
由于数据规模适中且难度较低,MNIST 已成为衡量算法性能的标准之一:
传统机器学习算法:
- 支持向量机(SVM):准确率约 97%-98%
- 随机森林:准确率约 95%-97%
深度学习算法:
- 简单全连接神经网络:准确率约 98%-99%
- 卷积神经网络(CNN):准确率可达 99.5% 以上
- 最新技术(如胶囊网络):准确率接近 99.7%
5、MNIST 的扩展与变种
- EMNIST:扩展版 MNIST,包含更多字符类别(大写字母、小写字母、数字)
- Fashion-MNIST:由 Zalando 提供的服装图像数据集,结构与 MNIST 完全一致,更具挑战性
- KMNIST:日本手写假名数据集,格式与 MNIST 兼容
- MNIST-C:添加了各种噪声和干扰的 MNIST 变种,用于测试模型鲁棒性
四、了解下cifar数据集,尝试获取其中一张图片
1、CIFAR 数据集概述
CIFAR(Canadian Institute For Advanced Research)数据集是计算机视觉领域的经典基准数据集,由加拿大高级研究所的 Alex Krizhevsky、Vinod Nair 和 Geoffrey Hinton 创建,主要用于图像分类任务的模型训练与评估。该数据集分为 CIFAR-10 和 CIFAR-100 两个版本,两者在数据结构和应用场景上既有相似性又有明显区别。
2、CIFAR-10 数据集详解
a. 数据规模与结构
- 图像数量:共 60,000 张 32×32 的彩色图像,其中 50,000 张训练集、10,000 张测试集。
- 类别分布:10 个大类,每个类别包含 6,000 张图像,类别如下:
类别 示例图像 airplane 飞机 automobile 汽车 bird 鸟类 cat 猫 deer 鹿 dog 狗 frog 青蛙 horse 马 ship 船 truck 卡车
b. 数据特点
- 色彩与尺寸:RGB 三通道彩色图像,固定尺寸 32×32,像素值范围 [0, 255]。
- 难度挑战:图像分辨率低、目标物体占据像素少,且存在背景干扰、姿态变化等问题,对模型识别能力要求较高。
3、CIFAR-100 数据集详解
a. 数据规模与结构
- 图像数量:同 CIFAR-10,共 60,000 张 32×32 彩色图像(50,000 训练 + 10,000 测试)。
- 类别分布:100 个细分类别,每个类别包含 600 张图像,类别组织为 20 个超类(每个超类包含 5 个子类),例如:
- 超类 “动物”:包含 bear(熊)、tiger(老虎)、lion(狮子)等子类;
- 超类 “交通工具”:包含 car(汽车)、train(火车)、truck(卡车)等子类。
b. 与 CIFAR-10 的核心区别
维度 | CIFAR-10 | CIFAR-100 |
---|---|---|
类别数量 | 10 个大类 | 100 个细分类别(20 超类) |
分类难度 | 较低(类别差异明显) | 较高(细分类别易混淆) |
典型应用 | 基础模型验证 | 细粒度分类研究 |
4、数据预处理与加载(PyTorch 示例)
标准化参数
- CIFAR-10:
均值(0.4914, 0.4822, 0.4465)
,标准差(0.2023, 0.1994, 0.2010)
- CIFAR-100:
均值(0.5071, 0.4867, 0.4408)
,标准差(0.2675, 0.2565, 0.2761)
5、CIFAR 数据集的应用与挑战
a. 典型应用场景
- 模型性能评估:如 ResNet、DenseNet 等经典网络常以 CIFAR 为基准测试分类准确率;
- 数据增强研究:由于数据量有限,常用于验证数据增强技术(如 Cutout、Mixup)的效果;
- 半监督学习与迁移学习:小样本场景下的算法验证(如 FixMatch、SimCLR)。
b. 主要挑战
- 小尺寸图像:32×32 像素难以捕捉复杂细节,需模型具备强特征提取能力;
- 类别相似度:CIFAR-100 中同类超类的子类(如不同品种的狗)外观高度相似,分类难度大;
- 过拟合问题:训练集规模有限(5 万张),深度模型易出现过拟合,需结合正则化策略(如 Dropout、L2 正则)。
6、与 MNIST 数据集的对比
维度 | MNIST | CIFAR-10 | CIFAR-100 |
---|---|---|---|
图像类型 | 灰度(1 通道) | 彩色(3 通道) | 彩色(3 通道) |
图像尺寸 | 28×28 | 32×32 | 32×32 |
类别数量 | 10(手写数字) | 10(物体大类) | 100(物体细分类) |
任务难度 | 低(入门级) | 中(基础研究) | 高(进阶研究) |
典型准确率 | CNN 可达 99%+ | 主流模型约 95% | 主流模型约 85% |
7、CIFAR-10图片获取
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
torch.manual_seed(42)
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量并归一化到[0,1]
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # CIFAR-10数据集的均值和标准差
])
train_dataset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
test_dataset = datasets.CIFAR10(
root='./data',
train=False,
download=False,
transform=transform
)
# 随机选择一张图片,可以重复运行,每次都会随机选择
cifar_sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
cifar_image, cifar_label = train_dataset[cifar_sample_idx] # 获取图片和标签
import numpy as np
import matplotlib.pyplot as plt
# 定义CIFAR-10的类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
# 可视化CIFAR-10图像(需要反归一化)
def imshow_cifar(img, label=None):
# 使用CIFAR-10的均值和标准差进行反归一化
img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(3, 1, 1) + np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)
npimg = img.numpy()
# 调整通道顺序:[C,H,W] → [H,W,C]
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# 显示标签(如果提供)
if label is not None:
plt.title(f"Label: {label} ({classes[label]})")
plt.axis('off')
plt.show()
# 使用示例
print(f"Label: {cifar_label}")
imshow_cifar(cifar_image, cifar_label)
8、CIFAR-100图片获取
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
torch.manual_seed(42)
# CIFAR-100的归一化参数
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量并归一化到[0,1]
transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) # CIFAR-100数据集的均值和标准差
])
# 加载CIFAR-100数据集
train_dataset = datasets.CIFAR100(
root='./data',
train=True,
download=True,
transform=transform
)
test_dataset = datasets.CIFAR100(
root='./data',
train=False,
download=False,
transform=transform
)
# 随机选择一张图片
cifar100_sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()
cifar100_image, cifar100_label = train_dataset[cifar100_sample_idx]
# 定义CIFAR-100的类别名称(完整列表)
cifar100_classes = [
'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum',
'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark',
'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel',
'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone',
'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle',
'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
]
# 可视化CIFAR-100图像(需要反归一化)
def imshow_cifar100(img, label=None):
# 使用CIFAR-100的均值和标准差进行反归一化
img = img * np.array([0.2675, 0.2565, 0.2761]).reshape(3, 1, 1) + np.array([0.5071, 0.4867, 0.4408]).reshape(3, 1, 1)
npimg = img.numpy()
# 调整通道顺序:[C,H,W] → [H,W,C]
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# 显示标签
if label is not None:
plt.title(f"Label: {label} ({cifar100_classes[label]})")
plt.axis('off')
plt.show()
# 使用示例
print(f"Label: {cifar100_label}")
imshow_cifar100(cifar100_image, cifar100_label)