小白的进阶之路系列之十二----人工智能从初步到精通pytorch综合运用的讲解第五部分

发布于:2025-06-07 ⋅ 阅读:(15) ⋅ 点赞:(0)

在本笔记本中,我们将针对Fashion-MNIST数据集训练LeNet-5的变体。Fashion-MNIST是一组描绘各种服装的图像瓦片,有十个类别标签表明所描绘的服装类型。

# PyTorch model and training necessities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Image datasets and image manipulation
import torchvision
import torchvision.transforms as transforms

# Image display
import matplotlib.pyplot as plt
import numpy as np

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter

# In case you are using an environment that has TensorFlow installed,
# such as Google Colab, uncomment the following code to avoid
# a bug with saving embeddings to your TensorBoard directory

# import tensorflow as tf
# import tensorboard as tb
# tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

在TensorBoard中显示图像

让我们首先将数据集中的样本图像添加到TensorBoard:

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

if __name__ == '__main__':
    
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

    # Store separate training and validations splits in ./data
    training_set = torchvision.datasets.FashionMNIST('./data',
        download=True,
        train=True,
        transform=transform)
    validation_set = torchvision.datasets.FashionMNIST('./data',
        download=True,
        train=False,
        transform=transform)

    training_loader = torch.utils.data.DataLoader(training_set,
                                                  batch_size=4,
                                                  shuffle=True,
                                                  num_workers=2)


    validation_loader = torch.utils.data.DataLoader(validation_set,
                                                    batch_size=4,
                                                    shuffle=False,
                                                    num_workers=2)

    # Class labels
    classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
            'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')



    # Extract a batch of 4 images
    dataiter = iter(training_loader)
    images, labels = next(dataiter)

    # Create a grid from the images and show them
    img_grid = torchvision.utils.make_grid(images)
    matplotlib_imshow(img_grid, one_channel=True)
    plt.show()

输出为:

在这里插入图片描述

上面,我们使用TorchVision和Matplotlib创建了一个小批量输入数据的视觉网格。下面,我们在SummaryWriter上使用add_image()调用来记录TensorBoard使用的图像,并且我们还调用flush())来确保它立即写入磁盘。

    # Default log_dir argument is "runs" - but it's good to be specific
    # torch.utils.tensorboard.SummaryWriter is imported above
    writer = SummaryWriter('runs/fashion_mnist_experiment_1')

    # Write image data to TensorBoard log dir
    writer.add_image('Four Fashion-MNIST Images', img_grid)
    writer.flush()

    # To view, start TensorBoard on the command line with:
    #   tensorboard --logdir=runs
    # ...and open a browser tab to http://localhost:6006/

如果您在命令行启动TensorBoard并在新的浏览器选项卡中打开它(通常在localhost:6006),您应该在IMAGES选项卡下看到图像网格。

绘制标量以可视化训练

TensorBoard对于跟踪您的训练进度和效果非常有用。下面,我们将运行一个训练循环,跟踪一些指标,并保存数据供TensorBoard使用。

让我们定义一个模型来对图像块进行分类,以及一个用于训练的优化器和损失函数:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 4 * 4, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(sel

网站公告

今日签到

点亮在社区的每一天
去签到