PyTorch深度学习快速入门学习总结(四)

发布于:2025-08-06 ⋅ 阅读:(14) ⋅ 点赞:(0)

完整的模型验证(测试)套路

使用 CIFAR10 上训练好的模型对单张图片进行分类

单张图片测试代码(注意:torch.load 加载的是整个保存的模型,因此加载前必须先定义 Chenxi 类):

import torch
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d, Sequential, Conv2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
from PIL import Image


# Load the test image and force 3-channel RGB: the network's first Conv2d
# expects 3 input channels, while e.g. RGBA PNGs or grayscale images would
# otherwise produce a different channel count.
image_path = './image/jinx.jpg'
with Image.open(image_path) as img:
    # convert() returns a new, fully loaded image, so the source file handle
    # can be released as soon as the with-block ends.
    image = img.convert('RGB')
print(image)

# Resize to 32x32 (the input size implied by the model's Linear(1024, ...)
# layer) and convert to a float tensor of shape (C, H, W) with values in [0, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image = transform(image)
print(image.shape)

# Network definition. The class must be in scope before torch.load() restores
# the saved model, because unpickling looks the class up by name.
class Chenxi(nn.Module):
    """Small CNN for 10-class, 32x32 RGB images.

    Three conv/max-pool stages followed by two linear layers. Input is a batch
    of shape (N, 3, 32, 32); output is raw (unnormalized) class scores of
    shape (N, 10).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # A 5x5 conv with padding=2 keeps the spatial size; each 2x2 max-pool
        # halves it: 32 -> 16 -> 8 -> 4. After flattening, the feature vector
        # has 64 channels * 4 * 4 = 1024 values.
        layers = [
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(64 * 4 * 4, 64),
            Linear(64, 10),
        ]
        # The attribute name and layer order fix the state_dict keys
        # (model1.0.weight, ...), which must match any saved checkpoint.
        self.model1 = Sequential(*layers)

    def forward(self, x):
        """Return the class scores (logits) for the batch *x*."""
        return self.model1(x)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the self-trained model. The checkpoint holds a whole pickled nn.Module
# (hence the Chenxi class must be defined above).
# - map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# - weights_only=False is needed to unpickle a full module: PyTorch 2.6 made
#   weights_only=True the default.
# SECURITY: weights_only=False executes arbitrary pickle code; only load
# checkpoint files you created yourself or fully trust.
model = torch.load('chenxi_0.pth', map_location=device, weights_only=False)
print(model)

# Add the batch dimension the network expects: (3, 32, 32) -> (1, 3, 32, 32).
image = torch.reshape(image, (1, 3, 32, 32))
# The input must live on the same device as the model's parameters.
image = image.to(device)

model.eval()
with torch.no_grad():
    output = model(image)
print(output)
# Example output (on a CUDA machine):
# tensor([[-0.7223,  0.4807, -0.1583,  0.2034, -0.0316,  0.4585,  0.3231,  0.1076,
#          -0.4046,  0.2497]], device='cuda:0')

# argmax over dim 1 (the 10 class scores) gives the predicted class index.
print(output.argmax(1))
# tensor([1], device='cuda:0')

网站公告

今日签到

点亮在社区的每一天
去签到