PyTorch的DataLoader是数据加载的核心工具,可高效处理批量数据、并行加载和自动打乱。以下是一个结合实例的分步讲解:
1. 基础使用流程
import torch
from torch.utils.data import Dataset, DataLoader
# 自定义数据集类(必须实现__len__和__getitem__)
class MyDataset(Dataset):
def __init__(self, data):
self.data = data # 假设data是已加载的列表或张量
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
return sample # 返回单个样本
# 创建数据集实例
data = [torch.randn(3, 100, 100) for _ in range(100)] # 100张3通道100x100的假图片
dataset = MyDataset(data)
# 创建DataLoader
dataloader = DataLoader(
dataset,
batch_size=16, # 每批16个样本
shuffle=True, # 训练时打乱数据
num_workers=2, # 使用2个子进程加载数据
drop_last=True # 丢弃最后不足一个batch的数据
)
# 遍历数据
for batch in dataloader:
print(batch.shape) # 输出:torch.Size([16, 3, 100, 100])
2. 结合实际场景的完整示例
场景:图像分类任务(CIFAR10)
import torchvision
from torchvision import transforms
# 定义数据预处理
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
# 加载CIFAR10数据集
train_set = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform # 应用预处理
)
# 创建DataLoader
train_loader = DataLoader(
train_set,
batch_size=64,
shuffle=True,
num_workers=4
)
# 训练循环示例
for epoch in range(10):
for images, labels in train_loader:
images = images.to('cuda') # 数据转移到GPU
labels = labels.to('cuda')
# 此处插入模型训练代码...
3. 关键参数详解
- batch_size:控制内存消耗与梯度稳定性
- shuffle:训练集=True(防止模型记忆顺序),测试集=False
- num_workers:根据CPU核心数调整(建议值:CPU核心数-1)
- collate_fn:自定义批次处理逻辑(处理不同尺寸数据时有用)
- pin_memory:当使用GPU时=True(加速数据到GPU的传输)
4. 处理非对齐数据(自定义collate_fn)
def collate_fn(batch):
# batch是包含多个__getitem__返回值的列表
images = [item[0] for item in batch]
labels = [item[1] for item in batch]
# 对图像进行动态填充
images = torch.nn.utils.rnn.pad_sequence(images, batch_first=True)
labels = torch.tensor(labels)
return images, labels
loader = DataLoader(dataset, collate_fn=collate_fn)
5. 性能优化技巧
预加载数据:对于小数据集,使用TensorDataset直接加载到内存
data = torch.randn(1000, 3, 256, 256)
labels = torch.randint(0, 10, (1000,))
dataset = torch.utils.data.TensorDataset(data, labels)
多进程优化:设置num_workers后,建议禁用共享内存
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
混合精度训练:结合autocast使用
with torch.cuda.amp.autocast():
for data in dataloader:
# 训练代码...
6. 常见问题排查
- 内存不足:降低
batch_size
或使用梯度累积
- 数据加载慢:检查磁盘I/O速度,增加num_workers
- 数据不匹配:检查__getitem__返回的维度顺序是否与模型匹配