完整的模型验证(测试)套路
CIFAR10测试集分类

单张图片测试代码：
import torch
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d, Sequential, Conv2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
from PIL import Image
# Path of the single test image to classify.
image_path = './image/jinx.jpg'

# Open the file inside a context manager so the handle is released.
# convert() returns a new, fully loaded image, so it stays usable after the
# file is closed. Forcing RGB also drops an alpha channel or palette (e.g.
# from a PNG), so the tensor below has exactly 3 channels.
with Image.open(image_path) as img:
    image = img.convert('RGB')
print(image)

# Resize to the 32x32 resolution the CIFAR-10 network takes, then turn the
# image into a float tensor of shape (3, 32, 32) with values in [0, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image = transform(image)
print(image.shape)
# Build the network. The class must be defined here so that torch.load can
# unpickle the saved model object below.
class Chenxi(nn.Module):
    """CIFAR-10 classifier: three conv + max-pool stages, then two linear layers.

    Layout of ``self.model1`` (the indices are part of the state_dict keys, so
    they must not change if a saved checkpoint is to keep loading):
    0/2/4 are 5x5 ``Conv2d`` layers (padding 2 keeps the spatial size),
    1/3/5 are 2x2 ``MaxPool2d`` layers, 6 is ``Flatten``,
    7 is ``Linear(1024, 64)`` and 8 is ``Linear(64, 10)``.

    NOTE(review): there is no ReLU (or other activation) between layers;
    kept as-is so the architecture matches already-trained checkpoints.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        stages = []
        for in_ch, out_ch in ((3, 32), (32, 32), (32, 64)):
            stages.append(Conv2d(in_ch, out_ch, 5, padding=2))
            stages.append(MaxPool2d(2))
        # A 32x32 input is halved three times: 64 channels * 4 * 4 = 1024.
        stages.extend([Flatten(), Linear(1024, 64), Linear(64, 10)])
        self.model1 = Sequential(*stages)

    def forward(self, x):
        """Return raw class scores (logits) of shape (batch, 10)."""
        return self.model1(x)
# Use the GPU when one is available, otherwise fall back to the CPU, so the
# same script also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the whole pickled model trained earlier.
# - map_location moves tensors that were saved on the GPU onto `device`.
# - weights_only=False is needed to unpickle a full nn.Module on
#   torch >= 2.6, where the default became True (the argument exists since
#   torch 1.13). Unpickling can execute arbitrary code: only load
#   checkpoints you trust.
# NOTE(review): unpickling needs the Chenxi class to be importable under the
# module path it was saved from (presumably __main__ here) -- confirm.
model = torch.load('chenxi_0.pth', map_location=device, weights_only=False)
model = model.to(device)
print(model)

# The network expects a batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
image = torch.reshape(image, (1, 3, 32, 32))
image = image.to(device)

model.eval()  # switch layers to inference behaviour
with torch.no_grad():  # no gradients are needed for a prediction
    output = model(image)
print(output)
# Example output on a CUDA machine (one raw score per class):
# tensor([[-0.7223, 0.4807, -0.1583, 0.2034, -0.0316, 0.4585, 0.3231, 0.1076,
# -0.4046, 0.2497]], device='cuda:0')
print(output.argmax(1))  # index of the highest score = predicted class
# tensor([1], device='cuda:0')