# PyTorch model and training necessities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Image datasets and image manipulation
import torchvision
import torchvision.transforms as transforms
# Image display
import matplotlib.pyplot as plt
import numpy as np
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
# Gather datasets and prepare them for consumption
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# Store separate training and validations splits in ./data
training_set = torchvision.datasets.FashionMNIST('./data',
download=True,
train=True,
transform=transform)
validation_set = torchvision.datasets.FashionMNIST('./data',
download=True,
train=False,
transform=transform)
training_loader = torch.utils.data.DataLoader(training_set,
batch_size=4,
shuffle=True,
num_workers=2)
validation_loader = torch.utils.data.DataLoader(validation_set,
batch_size=4,
shuffle=False,
num_workers=2)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
这段代码的作用是加载并预处理 FashionMNIST 数据集,FashionMNIST 是一个包含 28x28 像素灰度图像的服装数据集,用于图像分类任务。我们来逐步解释代码:
1. 数据预处理(Transformations)
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
transforms.Compose
:是一个工具,可以将多个操作按顺序组合在一起,对数据进行一系列处理。transforms.ToTensor()
:这一步将图像转换成 PyTorch 中可以使用的 张量(Tensor) 格式。并且把像素值从[0, 255]
范围变成[0, 1]
之间。transforms.Normalize((0.5,), (0.5,))
:这一步对图像做归一化处理,把像素值的均值调整为 0.5,标准差调整为 0.5。这样可以帮助神经网络更快地收敛,提高训练效果。
2. 加载数据集(Dataset)
training_set = torchvision.datasets.FashionMNIST('./data',
download=True,
train=True,
transform=transform)
validation_set = torchvision.datasets.FashionMNIST('./data',
download=True,
train=False,
transform=transform)
这里使用了
torchvision.datasets.FashionMNIST
来加载 FashionMNIST 数据集。./data
:指定了数据集保存的目录。如果没有下载,程序会自动下载到这个目录。download=True
:表示如果本地没有数据集,就自动从网上下载。train=True
:表示加载的是 训练集。train=False
:表示加载的是 测试集(验证集)。transform=transform
:表示在加载数据时,应用之前定义的预处理操作(转为张量并归一化)。
3. 数据加载器(DataLoader)
training_loader = torch.utils.data.DataLoader(training_set,
batch_size=4,
shuffle=True,
num_workers=2)
validation_loader = torch.utils.data.DataLoader(validation_set,
batch_size=4,
shuffle=False,
num_workers=2)
DataLoader
是一个用来批量加载数据的工具,这样可以提高训练时的效率。batch_size=4
:每次加载 4 张图片,这些图片会组成一个批次(batch)。shuffle=True
:表示在训练时,数据会被随机打乱,这样有助于防止模型记住数据的顺序,提高训练效果。shuffle=False
:在验证时,不需要打乱数据,保持数据顺序。num_workers=2
:表示有 2 个子进程 用来并行加载数据,这样可以加快数据读取的速度。
4. 类标签(Classes)
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
这是一个包含 10 个类别 的列表,每个类别对应着一种服装类型。每张图片的标签就表示它是哪一类服装:
0
代表 T恤/上衣1
代表 裤子2
代表 毛衣3
代表 裙子4
代表 外套5
代表 凉鞋6
代表 衬衫7
代表 运动鞋8
代表 包9
代表 高筒靴
总结
这段代码的主要目的是:
加载 FashionMNIST 数据集,并将其分为训练集和验证集。
对数据进行 预处理,将图像转换为 PyTorch 张量,并进行标准化。
使用 DataLoader 批量加载数据,方便训练时一次性输入多个图像。
定义了图像的 类别标签,这些标签用于后续训练和评估模型。
简而言之,这段代码准备了数据,使得我们可以直接拿来训练一个图像分类模型,自动识别不同类型的服装。
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
if one_channel:
img = img.mean(dim=0)
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
if one_channel:
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# Extract a batch of 4 images
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
# Default log_dir argument is "runs" - but it's good to be specific
# torch.utils.tensorboard.SummaryWriter is imported above
writer = SummaryWriter('runs/fashion_mnist_experiment_1')
# Write image data to TensorBoard log dir
writer.add_image('Four Fashion-MNIST Images', img_grid)
writer.flush()
# To view, start TensorBoard on the command line with:
# tensorboard --logdir=runs
# ...and open a browser tab to http://localhost:6006/
这段代码是用来从训练集中取出一批图像,并将它们可视化显示出来的,适用于调试或初步查看训练数据。下面我们用通俗的中文逐行解释:
✅ 函数:显示图片的辅助函数
def matplotlib_imshow(img, one_channel=False):
这个函数叫做 matplotlib_imshow
,是一个用于显示图像的辅助工具。你传入一张图,它就会用 matplotlib
来画出来。
if one_channel:
img = img.mean(dim=0)
如果你传入的 one_channel=True
,说明你希望把图像变成灰度图(单通道),所以这一步就会把图像的颜色通道取平均,变成灰度图。
img = img / 2 + 0.5 # unnormalize
这一步是“反归一化”,因为之前图像被归一化成了 [-1, 1]
之间(即 Normalize((0.5,), (0.5,))
),现在要显示出来,就把它还原成 [0, 1]
范围的像素值:
原始像素 = 归一化像素 × 0.5 + 0.5
npimg = img.numpy()
将图像从 PyTorch 的 Tensor 转换成 NumPy 数组,因为 matplotlib
是用 NumPy 数据绘图的。
if one_channel:
plt.imshow(npimg, cmap="Greys")
如果是灰度图,就用灰色的色图(colormap)来显示图像。
else:
plt.imshow(np.transpose(npimg, (1, 2, 0)))
如果是彩色图,需要把通道维度(channel)从 [C, H, W]
转换成 [H, W, C]
,因为 matplotlib
要求的是这种格式(高度、宽度、颜色通道)。
✅ 提取一批图像(batch)
dataiter = iter(training_loader)
images, labels = dataiter.next()
这两行代码的意思是:
从训练数据加载器
training_loader
中取出一个批次的数据(batch size = 4),并分为图像和标签两部分。images
就是形状为[4, 1, 28, 28]
的张量,表示 4 张 28×28 的灰度图像。
✅ 创建图像网格并显示
img_grid = torchvision.utils.make_grid(images)
将这 4 张图像拼成一个图像网格(横着排列),返回一个可以显示的图像张量。
matplotlib_imshow(img_grid, one_channel=True)
调用刚才定义的函数 matplotlib_imshow()
,将这个图像网格以灰度图的形式显示出来。
📌 总结一句话:
这段代码从训练集中随机取出 4 张图像,拼成一张图像网格,并使用 matplotlib
显示出来,方便我们快速了解训练数据的样子。
# Select a random subset of data and corresponding labels
def select_n_random(data, labels, n=100):
assert len(data) == len(labels)
perm = torch.randperm(len(data))
return data[perm][:n], labels[perm][:n]
# Extract a random subset of data
images, labels = select_n_random(training_set.data, training_set.targets)
# get the class labels for each image
class_labels = [classes[label] for label in labels]
# log embeddings
features = images.view(-1, 28 * 28)
writer.add_embedding(features,
metadata=class_labels,
label_img=images.unsqueeze(1))
writer.flush()
writer.close()
好的,我来用通俗易懂的中文一步步详细解释这段代码的意思:
1. 选择一小部分数据
def select_n_random(data, labels, n=100):
assert len(data) == len(labels)
perm = torch.randperm(len(data))
return data[perm][:n], labels[perm][:n]
这段代码定义了一个函数 select_n_random
,它的作用是:
随机从一大堆数据(
data
)和对应的标签(labels
)里面,挑选出 n 个样本(默认是 100 个)。torch.randperm(len(data))
是生成一个随机排列,比如数据有 10000张图片,它就打乱这10000个顺序,然后从中取前 n 个。注意:
assert len(data) == len(labels)
是确保图片和标签数量一一对应,不然就出错。
简单理解: 👉 就像从一堆扑克牌中随机抽出 100 张牌一样。
2. 提取随机子集
images, labels = select_n_random(training_set.data, training_set.targets)
这一行用上面定义的函数,从你的训练集 (training_set
) 中随机挑了 100 张图片及对应的标签。
images
保存了随机抽到的图片labels
保存了这些图片对应的正确类别
3. 拿到每张图片的文字标签
class_labels = [classes[label] for label in labels]
这里把数字标签(比如0、1、2)转换成了文字标签(比如 "T-shirt", "Trouser")。
classes
是一个列表,比如:classes = ['T-shirt', 'Trouser', 'Pullover', ...]
简单理解: 👉 就是把数字变成更好懂的中文/英文类别名。
4. 准备好数据做可视化(特征降维)
features = images.view(-1, 28 * 28)
把每张图片(原来是 28×28 的二维小图片)拉平成一行 784 个数字,因为后面要把这些数字送给 TensorBoard 画图。
简单理解: 👉 把小方块图片拉成长长的一条数据线。
5. 写入到 TensorBoard 的 Embedding
writer.add_embedding(features,
metadata=class_labels,
label_img=images.unsqueeze(1))
把这些图片的特征、对应的文字标签、以及原始图片,全部写入到 TensorBoard。
metadata=class_labels
:就是告诉 TensorBoard,这个点对应的是什么类别。label_img=images.unsqueeze(1)
:把图片加一个通道数(变成 1 通道),符合 TensorBoard 要求的格式。
简单理解: 👉 把这些图片的数据、名字和图片本身,全部打包进 TensorBoard,方便后面可视化查看。
6. 刷新并关闭文件
writer.flush()
writer.close()
flush()
是确保所有数据被保存到磁盘,不然可能还有东西留在内存里没写完。close()
是关闭日志文件,结束写入。
总结一下通俗版流程:
随机选 100张训练图片。
拿到图片的文字类别。
拉平成一行数据。
写到 TensorBoard,方便后面用图形界面直观地看。
最终你可以在 TensorBoard 上看到:
每一张图片在空间中的分布(像小点点一样)
点点上可以标注类别名字,甚至直接看到小图片!