datasets.ImageFolder用法,是将文件夹的名字转化为标签。用于分类任务。
from torchvision import datasets
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
transform=data_transform["train"])
比如在flower_photos文件夹下存放着五个子文件夹,分别存放着各种类别的图像。
|-- flower_photos
|-- daisy
|-- dandelion
|-- roses
|-- sunflowers
|-- tulips
使用datasets.ImageFolder后就可以将daisy转化为0,dandelion为1......。
flower_list = train_dataset.class_to_idx
这行代码可以获取数据的类别数以及对应的类别标签。以字典的形式保存。
输出为:
{'daisy':0, 'dandelion':1, 'roses':2, 'sunflowers':3, 'tulips':4}
然后就可以使用torch.utils.data.DataLoader加载了。
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,
num_workers=nw)
完整代码
可以先将数据集划分成以下的目录结构
|-- train
|-- daisy
|-- dandelion
|-- roses
|-- sunflowers
|-- tulips
|-- val
|-- daisy
|-- dandelion
|-- roses
|-- sunflowers
|-- tulips
然后就可以进行以下加载了
def main():
    """Prepare the flower-classification data pipeline.

    Builds train/val transforms, loads both splits with
    ``datasets.ImageFolder`` (sub-folder name -> class index), persists the
    index->class mapping to ``class_indices.json`` for later inference, and
    wraps both datasets in ``DataLoader`` instances.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        # Resize needs the full (224, 224) tuple here; a bare 224 would only
        # rescale the shorter side and preserve the aspect ratio.
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    # NOTE(review): data_root is never read in this excerpt — presumably used
    # further down in the full script; confirm before removing.
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join("/kaggle/working/", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    # ImageFolder maps each sub-folder name to a class index, e.g.
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflowers':3, 'tulips':4}
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # Invert class->index into index->class and write it to disk so the
    # prediction script can translate model outputs back to class names.
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # Worker count: capped by CPU count, batch size, and a hard ceiling of 8.
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    # Validation: small fixed batch, no shuffling.
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))