引言
本次准备建立一个卷积神经网络模型,用于区分鸟和飞机,并从CIFAR-10数据集中选出所有鸟和飞机作为本次的数据集。
以此为例,介绍一个神经网络模型从数据集准备、数据归一化处理、模型网络函数定义、模型训练、结果验证、模型文件保存,端到端的模型全生命周期,方便大家深入了解AI模型开发的全过程。
一、网络场景定义与数据集准备
1.1 数据集准备
本次我准备使用CIFAR10数据集,它是一个简单有趣的数据集,由60000张小RGB图片构成(32像素*32像素),每张图类别标签用1~10数字表示
%matplotlib inline
from matplotlib import pyplot as plt
from torchvision import datasets
data_path = '/content/sample_data'
cifar10 = datasets.CIFAR10(data_path, train=True, download=True)
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True)
type(cifar10).__mro__
1.2 查看数据集类别示例
class_names = ['airplane', 'aotomobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
fig = plt.figure(figsize=(8, 3))
num_classes = 10
for i in range(num_classes):
ax = fig.add_subplot(2, 5 ,1 + i, xticks=[], yticks=[])
ax.set_title(class_names[i])
img = next(img for img, label in cifar10 if label == i)
plt.imshow(img)
plt.show()
1.2.1 输出单张图像类别及展示图片
img, label = cifar10[99]
img, label, class_names[label]
plt.imshow(img)
plt.show()
1.3 数据集Dataset变换
使用torchvision.transforms模块,将PIL图像变换为PyTorch张量,用于图像分类
1.3.1 将单张图像转换为张量,输出张量大小
from torchvision import transforms
to_tensor = transforms.ToTensor()
img_t = to_tensor(img)
img_t.shape
1.3.2 将CIFAR10数据集转换为张量
tensor_cifar10 = datasets.CIFAR10(data_path, train=True, download=False, transform=transforms.ToTensor())
tensor_cifar10.__len__()
1.4 数据归一化
使用transforms.Compose()将图像连接起来,在数据加载器中直接进行数据归一化和数据增强操作
使用transforms.Normalize(),计算数据集中每个通道的平均值和标准差,使每个通道的均值为0,标准差为1
imgs = torch.stack([img_t for img_t, _ in tensor_cifar10], dim=3)
imgs.shape
1.4.1 计算每个通道的平均值(mean)
imgs.view(3, -1).mean(dim=1)
1.4.2 计算每个通道的标准差(stdev)
imgs.view(3, -1).std(dim=1)
1.4.3 使用transforms.Normailze()对数据集归一化
使每个数据集的通道的均值为0,标准差为1
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
二、使用nn.Module编写第一个识别鸟与飞机的网络模型
2.1 构建鸟与飞机的训练集和验证集
2.1.1 准备CIFAR10数据集
cifar10 = datasets.CIFAR10(
data_path, train=True, download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2470, 0.2435, 0.2616))
]))
cifar10_val = datasets.CIFAR10(
data_path, train=True, download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2470, 0.2435, 0.2616))
]))
2.1.2 构建CIFAR2-数据集
label_map = {0:0, 2:1}
class_names = ['airplane', 'bird']
cifar2 = [(img, label_map[label])
for img, label in cifar10
if label in [0, 2]
]
cifar2.__len__()
2.1.3 构建CIFAR2-验证集
cifar2_val = [(img, label_map[label])
for img, label in cifar10_val
if label in [0, 2]
]
cifar2_val.__len__()
2.1.4 准备批处理图像
img, _ = cifar2[0]
plt.imshow(img.permute(1, 2, 0))
plt.show()
img
img.shape
2.2 编写第一个nn.Module子模块的网络定义
import torch
import torch.nn as nn
import torch.optim as optim
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.act1 = nn.Tanh()
self.pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
self.act2 = nn.Tanh()
self.pool2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(8 * 8 * 8, 32)
self.act3 = nn.Tanh()
self.fc2 = nn.Linear(32, 2)
def forward(self, x):
out = self.pool1(self.act1(self.conv1(x)))
out = self.pool2(self.act2(self.conv2(out)))
out = out.view(-1, 8 * 8 * 8)
out = self.act3(self.fc1(out))
out = self.fc2(out)
return out
2.2.1 将网络模型实例化,并输出模型参数
model = Net()
numel_list = [p.numel() for p in model.parameters()]
sum(numel_list), numel_list
2.3 使用函数式API,优化nn.Module网络函数定义
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
self.fc1 = nn.Linear(8 * 8 * 8, 32)
self.fc2 = nn.Linear(32, 2)
def forward(self, x):
out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)
out = out.view(-1, 8 * 8 * 8)
out = torch.tanh(self.fc1(out))
out = self.fc2(out)
return out
model = Net()
model(img.unsqueeze(0))
2.4 定义网络模型的训练循环函数,并执行训练
import datetime
def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):
for epoch in range(1, n_epochs + 1):
loss_train = 0.0
for imgs, labels in train_loader:
outputs = model(imgs)
loss = loss_fn(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_train += loss.item()
if epoch == 1 or epoch %10 == 0:
print('{} Epoch {}, Training loss{}'.format(
datetime.datetime.now(), epoch, loss_train / len(train_loader)))
train_loader = torch.utils.data.DataLoader(cifar2, batch_size = 64, shuffle=True)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()
training_loop(
n_epochs = 100,
optimizer = optimizer,
model = model,
loss_fn = loss_fn,
train_loader = train_loader,
2.4.1 训练结果(耗时7分钟)
2025-08-17 15:13:20.123706 Epoch 1, Training loss0.5672952472024663
2025-08-17 15:14:01.667640 Epoch 10, Training loss0.32902660861516453
2025-08-17 15:14:47.187795 Epoch 20, Training loss0.2960508146863075
2025-08-17 15:15:33.119990 Epoch 30, Training loss0.26820498961172284
2025-08-17 15:16:19.303661 Epoch 40, Training loss0.24607981879050564
2025-08-17 15:17:04.858228 Epoch 50, Training loss0.22783752284042394
2025-08-17 15:17:50.712569 Epoch 60, Training loss0.2095268357806145
2025-08-17 15:18:36.846523 Epoch 70, Training loss0.19460647420328894
2025-08-17 15:19:22.404563 Epoch 80, Training loss0.18098321051639357
2025-08-17 15:20:08.067236 Epoch 90, Training loss0.16757476806735536
2025-08-17 15:20:54.041604 Epoch 100, Training loss0.15512346253273593
2.5 测量准确率(使用验证集)
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)
def validate(model, train_loader, val_loader):
for name, loader in [("train", train_loader), ("val", val_loader)]:
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in loader:
outputs = model(imgs)
_, predicted = torch.max(outputs, dim=1)
total += labels.shape[0]
correct += int((predicted == labels).sum())
print("Accuracy {}: {:.2f}".format(name, correct/total))
validate(model, train_loader, val_loader)
2.6 保存并加载我们的模型
2.6.1 保存模型
torch.save(model.state_dict(), data_path + 'birds_vs_airplanes.pt')
torch.save(model.state_dict(), data_path + 'birds_vs_airplanes.pt')
2.6.2 模型pt文件生成
包含模型的所有参数,即2个卷积模块和2个线性模块的权重和偏置
2.6.3 加载参数到模型实例
loaded_model = Net()
loaded_model.load_state_dict(torch.load(data_path+'birds_vs_airplanes.pt'))
三、小结
至此,我们完成一个卷积神经网络模型birds_vs_airplanes的构建,可用于图像分类识别,区分图片是鸟还是飞机,准确性高达94!
我们从数据集准备、数据集准备、数据归一化处理、模型网络函数定义、模型训练、结果验证、模型文件保存,并将模型参数加载到另一个新模型实例中,端到端完整串联一个神经网络模型全生命周期的过程,加深对AI模型开发的理解,这是个经典案例,快来试试吧~