【torchvision】2.5 torchvision 4大类: datasets、models、transforms、utils

发布于:2023-01-16 ⋅ 阅读:(533) ⋅ 点赞:(0)

官网地址:https://pytorch.org/vision/stable/index.html

Torchvision 是 PyTorch 的一个视觉处理工具包,独立于PyTorch,需要另外安装

它包括4个类,各类的主要功能如下:

  • 1)datasets:提供常用的数据集加载,设计上都是继承自torch.utils.data.Dataset,主要包括MMIST、CIFAR10/100、ImageNet和COCO等。
  • 2)models:提供深度学习中各种经典的网络结构以及训练好的模型(参数选择 pretrained=True),包括AlexNet、VGG系列、ResNet系列、Inception系列等。
  • 3)transforms:常用的数据预处理操作,主要包括对 Tensor 及 PIL Image对象 的操作。
  • 4)utils:含两个函数,一个是make_grid,它能将多张图片拼接在一个网格中;另一个是 save_img,它能将 Tensor 保存成图片
    在这里插入图片描述

1、torchvision.datasets

1.1 常用数据集加载 MNIST等

举例,通过torchvision下载 MNIST (mnist 全称:mixed national institute of standards and technology database)

train_dataset = torchvision.datasets.MNIST(root, 
                                           train=True, 
                                           transform=transform, 
                                           download=True)

root :需要下载至地址的根目录位置
train:如果是True, 下载训练集 trainin.pt; 如果是False,下载测试集 test.pt; 默认是True
transform:一系列作用在PIL图片上的转换操作,返回一个转换后的版本
download:是否下载到 root指定的位置,如果指定的root位置已经存在该数据集,则不再下载


1.2 自定义数据集读取 ImageFolder

torchvision.datasets.ImageFolder(root, transform, target_transform, loader)
  • root:图片存储的根目录,即各类别文件夹所在目录的上一级目录,在下面的例子中是 “…/input/data/”
  • transform:对图片进行预处理操作(函数),原始图片作为输入,返回一个转换后的图片。
  • target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
  • loader:表示数据集加载方式,通常默认加载方式即可

另外,该 API 有以下成员变量:

  • self.classes:用一个 list 保存类别名称
  • self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
  • self.imgs:保存(img-path, class) tuple的 list,与我们自定义 Dataset类的 def getitem(self, index): 返回值类似。注意看下面实例中 dataset.imgs 的返回值

举例:数据存储结构如下

在这里插入图片描述

import torchvision
from torchvision import transforms, utils

trans = transforms.Compose([transforms.RandomCrop(400), transforms.ToTensor()])
dataset = torchvision.datasets.ImageFolder('/Users/manmi/Desktop/data/data', transform=trans)

print(dataset.classes)   # ['bird', 'cat', 'dog']
print(dataset.class_to_idx)   # {'bird': 0, 'cat': 1, 'dog': 2}
print(dataset.imgs)   # [('/Users/manmi/Desktop/data/data/bird/bird1.jpeg', 0), ('/Users/manmi/Desktop/data/data/bird/bird2.jpeg', 0), ...]

print(len(dataset))   # 11
print(dataset[0][0].size())   # torch.Size([3, 400, 400])
print(dataset[0][1])   # 0

2、torchvision.models

torchvision.models 这个包中包含 alexnet、densenet、inception、resnet、squeezenet、vgg 等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型(用于迁移学习)

1)导入 resnet50 的预训练模型

import torchvision
model = torchvision.models.resnet50(pretrained=True)

2)如果只需要网络结构,不需要用预训练模型的参数来初始化,可以将参数设置为 pretrained=False,或者不设置 (参数 pretrained 默认是False)

from torchvision import models
model = torchvision.models.resnet50()

3、torchvision.transforms

3.1 对PIL Image的常见操作

1)转换为 tensor ToTensor()

ToTensor() 做了三件事:

  • 将每一个像素 灰度值 从 0~255 归一化到 [0,1],其归一化方法比较简单,直接除以255
  • 将 nump.ndarray 或 PIL.Image 转为 tensor,数据类型为 torch.FloatTensor
  • 将shape 由 (H,W, C) 转为shape为 (C, H, W)

2)中心裁剪 CenterCrop()

torchvision.transforms.CenterCrop(size)   # 所需裁剪的图片尺寸
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

img_src = Image.open('./bird.jpg')

img_1 = transforms.CenterCrop(200)(img_src)
img_2 = transforms.CenterCrop((200, 200))(img_src)
img_3 = transforms.CenterCrop((300, 200))(img_src)
img_4 = transforms.CenterCrop((500, 500))(img_src)

plt.subplot(231)
plt.imshow(img_src)
plt.subplot(232)
plt.imshow(img_1)
plt.subplot(233)
plt.imshow(img_2)
plt.subplot(234)
plt.imshow(img_3)
plt.subplot(235)
plt.imshow(img_4)
plt.show()

在这里插入图片描述

以上例子我们可知:
(1)如果切正方形,transforms.CenterCrop(100) 和 transforms.CenterCrop((100, 100)),两种写size的方法,效果一样
(2)如果设置的输出的图片尺寸大于原尺寸,会在边上补黑色


3)随机裁剪 RandomCrop()

# 依据给定的size随机裁剪
torchvision.transforms.RandomCrop(size, 
                      padding = None, 
                      pad_if_needed = False, 
                      fill=0, 
                      padding_mode ='constant')

功能:
从图片中随机裁剪出尺寸为 size 的图片,如果有 padding,那么先进行 padding,再随机裁剪 size 大小的图片。

参数:

  • size :所需裁剪的图片尺寸
  • padding: 设置填充大小
    • 当为 a 时,上下左右均填充 a 个像素
    • 当为 (a, b) 时,左右填充 a 个像素,上下填充 b 个像素
    • 当为 (a, b, c, d) 时,左上右下分别填充 a,b,c,d
  • pad_if_needed:当图片小于设置的 size,是否填
  • padding_mode
    • constant: 像素值由 fill 设定 (默认)
    • edge: 像素值由图像边缘像素设定
    • reflect: 镜像填充,最后一个像素不镜像。([1,2,3,4] -> [3,2,1,2,3,4,3,2])
    • symmetric: 镜像填充,最后一个像素也镜像。([1,2,3,4] -> [2,1,1,2,3,4,4,4,3])
  • fill:当 padding_mode 为 constant 时,设置填充的像素值 (默认为0)

4)其他更多图像变换操作

其他更多的图像变换操作,看这里吧


3.2 对 Tensor 的常见操作

1)归一化 Normalize()

作用: 用均值和标准差对张量图像进行归一化,
公式: i m a g e = ( i m a g e − m e a n ) / s t d image = (image-mean) / std image=(imagemean)/std

比如,原像素值的取值区间为 [0, 1],在使用 transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]) 进行归一化后,原像素值被分布到了 [-1, 1] 区间:

  • 原来的 0~1 最小值 0 则变成 (0 - 0.5) / 0.5 = -1
  • 最大值1则变成 (1 - 0.5) / 0.5 = 1

其中 mean 和 std 的3个值分表表示图像的3个通道
如果是单通道的灰度图,可以写成 transforms.Normalize(mean=[0.5], std=[0.5])

我们可能会看到很多代码里面是这样的:
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
这一组值是怎么来的呢?答案就是通过数据集,提前抽样计算出来的


2)转换为图像 ToPILImage()

将 Tensor类型数据转换为图片数据 PILImage, torchvision.transforms.ToPILImage() 函数的作用是把Tensor数据变为完整原始的图片数据(保存后可以直接双击打开的那种)
其内部处理过程为:

  • 将Tensor的每个元素乘以255
  • 将数据由Tensor转化成Uint8
  • 将Tensor转化成numpy的ndarray类型
  • 对ndarray对象做permute (1, 2, 0)的转置,将shape 由 (C, H, W) 转为shape为(H,W, C)
  • 将ndarray对象转化成PILImage数据格式
  • 输出该PILImage数据(save后可以直接打开)

4、torchvision.utils

4.1 图像拼接 grid

一行最多展示8张图片

import torch
import torchvision
from torchvision import transforms, utils
from torch.utils import data
import matplotlib.pyplot as plt

trans = transforms.Compose([transforms.RandomCrop(400), transforms.ToTensor()])
dataset = torchvision.datasets.ImageFolder('./data', transform=trans)
train_loader = data.DataLoader(dataset, batch_size=2, shuffle=True)

for (img, label) in train_loader:
    grid = utils.make_grid(img)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.show()
    break

在这里插入图片描述


4.2 tensor存储为图片 save_img

torchvision.utils.save_img(img, path)

image 的数据类型是tensor

本文含有隐藏内容,请 开通VIP 后查看

网站公告

今日签到

点亮在社区的每一天
去签到