目录
1、数据集torch.utils.data.Dataset
pytorch导入数据集主要依靠Dataset类。其中Dataset类是一个抽象类,在构建我们自己的MyDataset类时需要继承Dataset类。并重写(overwrite)它的两个方法:__getitem__和__len__。
方法 | 说明 |
__getitem__(self, index) | 根据传入的index参数返回列表中对应下标的一条数据(包括图片和标签)。 |
__len__(self) | 返回数据集的长度。 |
(1)实现MyDataset类
from torch.utils.data import Dataset
import os
from PIL import Image
class MyDataset(Dataset):
def __init__(self, root_dir, label):
self.root_dir = root_dir # 图片所在的根目录路径
self.label = label # 图片的标签
# 图片所在的根目录路径下所有图片的名称,
# 在__getitem__方法中会通过图片名称来获取图片的路径。
self.image_name = os.listdir(self.root_dir)
"""重写__getitem__方法"""
def __getitem__(self, index):
# 获取完整图片路径
image_path = os.path.join(self.root_dir, self.image_name[index])
# 读取图片
image = Image.open(image_path)
label = self.label
# 返回对应下标的图片和label
return image, label
"""重写__len__方法"""
def __len__(self):
# 返回数据集长度
return len(self.image_name)
(2)测试(测试数据集下载(蚂蚁蜜蜂数据集))
from torch.utils.data import Dataset
import os
from PIL import Image
class MyDataset(Dataset):
def __init__(self, root_dir, label):
self.root_dir = root_dir # 图片所在的根目录路径
self.label = label # 图片的标签
# 图片所在的根目录路径下所有图片的名称,
# 在__getitem__方法中会通过图片名称来获取图片的路径。
self.image_name = os.listdir(self.root_dir)
"""重写__getitem__方法"""
def __getitem__(self, index):
# 获取完整图片路径
image_path = os.path.join(self.root_dir, self.image_name[index])
# 读取图片
image = Image.open(image_path)
label = self.label
# 返回对应下标的图片和label
return image, label
"""重写__len__方法"""
def __len__(self):
# 返回数据集长度
return len(self.image_name)
"""采用蚂蚁蜜蜂数据集进行测试"""
root_dir = "data/dataset/train/bees"
label = "bees"
bees_dataset = MyDataset(root_dir=root_dir, label=label)
image, label = bees_dataset[6]
image.show()
(3)获取第6条数据并显示