Pytorch-03数据的Transform

发布于:2025-08-05 ⋅ 阅读:(13) ⋅ 点赞:(0)

Transforms

原始的数据形式可能并不符合模型算法所要求的输入形式。例如一个图片刚读入内存的时候可能还是numpy形式,而模型输入需要是tensor形式等等。亦或者模型要求的标签是one-hot,而现在的标签是整数等等的情况。

为了规范、统一的解决这个问题,pytorch定义了transform来在Dataset的__getitem__阶段,对数据进行处理,然后再交给Dataloader,再交给模型以供训练或推理。

在TorchVision中,所有的datasets都有两个可以指定的transform参数:

  • transform: 定义要怎么处理初始的features(数据样本)
  • target_transform:定义要怎么处理初始标签

torchvision.transforms模块提供了很多开箱即用的,常用的转换方法

对于FashionMNIST数据集,图片是PIL图片格式, 标签是整数形式,为了能进行分类训练,我们需要把图片转换成归一化之后的tensors,并且把整数标签转换为one-hot编码之后的tensor。

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
	root="data",
	train=True,
	download=True,
	transform=ToTensor(),
	target_transform=Lambda(lambda y; torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))

ToTensor()

这个方法会把一个PIL图片或者ndarray转换成FloatTensor, 并且把图片的像素值归一化到[0, 1]之间。

归一化对训练又很多好处,如加速训练,避免梯度爆炸或者梯度消失,让训练更加稳定。

Lambda Transforms

你可以用torchvision.transforms.Lambda将任何简单的函数或 lambda 表达式作为转换器来使用。这里是定义了一个标签转换为one-hot编码tensor的匿名函数。

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)) 

lambda y 表明是一个匿名函数,接受一个参数y,也就是标签的整数值,然后利用torch.zeros(10, dtype=torch.float)创建一个全0的一维张量,最后使用scatter_(dim=0. index=torch.tensor(y), value=1)就地操作把自己index为y的元素赋值为1,这样就实现了one-hot编码。


网站公告

今日签到

点亮在社区的每一天
去签到