datasets.ImageFolder和train_dataset.class_to_idx的用法

发布于:2024-03-06 ⋅ 阅读:(77) ⋅ 点赞:(0)

datasets.ImageFolder用法,是将文件夹的名字转化为标签。用于分类任务。

import json
import os

import torch
from torchvision import datasets, transforms

# Build the training dataset from <image_path>/train; ImageFolder uses each
# sub-folder name as a class and applies the "train" transform to every image.
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])

 比如在flower_photos文件夹下存放着五个子文件夹,分别存放着各种类别的图像。

    |-- flower_photos
        |-- daisy
        |-- dandelion
        |-- roses
        |-- sunflowers
        |-- tulips

 使用datasets.ImageFolder后,会按子文件夹名称排序后依次编号,即daisy转化为0,dandelion为1......,以此类推。

 # Dict mapping each class (sub-folder) name to its integer label.
 flower_list = train_dataset.class_to_idx

 这行代码可以获取类别名称到类别索引(标签)的映射,以字典的形式保存;字典的长度即为类别数。

 输出为:

 {'daisy':0, 'dandelion':1, 'roses':2, 'sunflowers':3, 'tulips':4}

然后就可以使用torch.utils.data.DataLoader加载了。

# Wrap the dataset in a DataLoader: batches of batch_size samples, reshuffled
# every epoch, loaded by nw worker processes.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=nw)

 

 完整代码

可以先将数据集划分成以下格式

    |-- train
        |-- daisy
        |-- dandelion
        |-- roses
        |-- sunflowers
        |-- tulips
    |-- val
        |-- daisy
        |-- dandelion
        |-- roses
        |-- sunflowers
        |-- tulips

 然后就可以进行以下加载了

def main():
    """Build the training and validation data loaders for the flower dataset.

    Reads images from ``/kaggle/working/flower_data/train`` and
    ``/kaggle/working/flower_data/val`` (one sub-folder per class), writes the
    index-to-class-name mapping to ``class_indices.json`` in the current
    working directory, and prints the selected device, the number of
    dataloader workers and the dataset sizes.

    Raises:
        AssertionError: If the dataset root directory does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Training uses random crop/flip augmentation; validation only resizes.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        # Resize(224) would only scale the shorter edge; (224, 224) forces a
        # fixed output size so validation images can be batched.
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    image_path = os.path.join("/kaggle/working/", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # e.g. {'daisy':0, 'dandelion':1, 'roses':2, 'sunflowers':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    # Invert the mapping (index -> class name) so predictions can be decoded later.
    cla_dict = {idx: name for name, idx in flower_list.items()}
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so min() never compares None with an int.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    # Validation order does not matter, so shuffling is disabled.
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))