一、任务描述
从手写数字图像中自动识别出对应的数字(0-9)” 的问题,属于单标签图像分类任务(每张图像仅对应一个类别,即 0-9 中的一个数字)
1、任务的核心定义:输入与输出
- 输入:28×28 像素的灰度图像(像素值范围 0-255,0 代表黑色背景,255 代表白色前景),图像内容是人类手写的 0-9 中的某一个数字,例如:一张 28×28 的图像,像素分布呈现 “3” 的形状,就是模型的输入。
- 输出:一个 “类别标签”,即从 10 个可能的类别(0、1、2、…、9)中选择一个,作为输入图像对应的数字,例如:输入 “3” 的图像,模型输出 “类别 3”,即完成一次正确识别。
- 目标:让模型在 “未见的手写数字图像” 上,尽可能准确地输出正确类别(通常用 “准确率” 衡量,即正确识别的图像数 / 总图像数)
2、任务的核心挑战
- 不同人书写习惯差异极大:有人写的 “4” 带弯钩,有人写的 “7” 带横线,有人字体粗大,有人字体纤细;甚至同一个人不同时间写的同一数字,笔画粗细、倾斜角度也会不同。例如:同样是 “5”,可能是 “直笔 5”“圆笔 5”,也可能是倾斜 10° 或 20° 的 “5”—— 模型需要忽略这些 “风格差异”,抓住 “数字的本质特征”(如 “5 有一个上半圆 + 一个竖线”)。
- 图像噪声与干扰:手写数字图像可能存在噪声,比如纸张上的污渍、书写时的断笔、扫描时的光线不均,这些都会影响像素分布。例如:一张 “0” 的图像,边缘有一小块污渍,模型需要判断 “这是噪声” 而不是 “0 的一部分”,避免误判为 “6” 或 “8”。
二、模型训练
1、MNIST数据集
MNIST(Modified National Institute of Standards and Technology database)是由美国国家标准与技术研究院(NIST)整理的手写数字数据集,后经修改(调整图像大小、居中对齐)成为机器学习领域的 “基准数据集”,MNIST手写数字识别的核心是 “让计算机从标准化的手写数字灰度图中,自动识别出对应的 0-9 数字”,它看似基础,却浓缩了图像分类的核心挑战(风格多样性、噪声鲁棒性、特征自动提取),同时是实际 OCR 场景的技术基础和机器学习入门的经典案例。
- 数据量适中:包含 70000 张图像,其中 60000 张用于训练(让模型学习特征),10000 张用于测试(验证模型泛化能力);
- 图像规格统一:所有图像都是 28×28 灰度图,无需复杂的预处理(如尺寸缩放、颜色通道处理),降低入门门槛;
- 标注准确:每张图像都有明确的 “正确数字标签”(人工标注),无需额外标注成本。
2、代码
- 数据准备:使用torchvision.datasets加载 MNIST 数据集,对数据进行转换(转为 Tensor 并标准化),使用DataLoader创建可迭代的数据加载器;
- 模型定义:定义了一个简单的两层神经网络SimpleNN,第一层将 28x28 的图像展平后映射到 128 维,第二层将 128 维特征映射到 10 个类别(对应数字 0-9);
- 训练设置:使用交叉熵损失函数(CrossEntropyLoss),使用 Adam 优化器,设置批量大小为64,训练轮次为5;
- 训练过程:循环多个训练轮次(epoch),每个轮次中迭代所有批次数据,执行前向传播、计算损失、反向传播和参数更新。
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 1. 数据准备
# 定义数据变换
transform = transforms.Compose([
transforms.ToTensor(), # 转换为Tensor
transforms.Normalize((0.1307,), (0.3081,)) # 标准化,MNIST数据集的均值和标准差
])
# 加载MNIST数据集
train_dataset = datasets.MNIST(
root='./data', # 数据保存路径
train=True, # 训练集
download=True, # 如果数据不存在则下载
transform=transform
)
test_dataset = datasets.MNIST(
root='./data',
train=False, # 测试集
download=True,
transform=transform
)
# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 2. 定义模型
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
# 输入层到隐藏层
self.fc1 = nn.Linear(28 * 28, 128) # MNIST图像大小为28x28
# 隐藏层到输出层
self.fc2 = nn.Linear(128, 10) # 10个类别(0-9)
def forward(self, x):
# 将图像展平为一维向量
x = x.view(-1, 28 * 28)
# 隐藏层,使用ReLU激活函数
x = torch.relu(self.fc1(x))
# 输出层,不使用激活函数(因为后面会用CrossEntropyLoss)
x = self.fc2(x)
return x
# 3. 初始化模型、损失函数和优化器
model = SimpleNN()
criterion = nn.CrossEntropyLoss() # 交叉熵损失,适用于分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器
# 4. 训练模型
def train(model, train_loader, criterion, optimizer, epochs=5):
model.train() # 设置为训练模式
train_losses = []
for epoch in range(epochs):
running_loss = 0.0
for batch_idx, (data, target) in enumerate(train_loader):
# 清零梯度
optimizer.zero_grad()
# 前向传播
outputs = model(data)
loss = criterion(outputs, target)
# 反向传播和优化
loss.backward()
optimizer.step()
running_loss += loss.item()
# 每100个批次打印一次信息
if batch_idx % 100 == 99:
print(
f'Epoch [{epoch + 1}/{epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')
running_loss = 0.0
train_losses.append(running_loss / len(train_loader))
return train_losses
# 6. 运行训练和测试
if __name__ == '__main__':
# 训练模型
print("开始训练模型...")
train_losses = train(model, train_loader, criterion, optimizer, epochs=5)
print("模型训练完成...")
# 保存模型
torch.save(model.state_dict(), 'mnist_model.pth')
print("模型已保存为 mnist_model.pth")
三、模型使用测试
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms # 修正transforms的导入方式
# 定义与训练时相同的模型结构
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(28*28, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(-1, 28*28)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 加载模型
def load_model(model_path='mnist_model.pth'):
model = SimpleNN()
# 加载模型时添加参数以避免潜在的Python 3兼容性问题
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
model.eval() # 设置为评估模式
return model
# 图像预处理(与训练时保持一致)
def preprocess_image(image_path):
# 打开图像并转换为灰度图
img = Image.open(image_path).convert('L') # 'L'表示灰度模式
# 调整大小为28x28
img = img.resize((28, 28))
# 转换为numpy数组并归一化
img_array = np.array(img) / 255.0
# 定义图像转换(使用torchvision的transforms)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 注意:这里需要先将numpy数组转换为PIL图像再应用transform
img_pil = Image.fromarray((img_array * 255).astype(np.uint8))
img_tensor = transform(img_pil).unsqueeze(0) # 增加批次维度
return img_tensor
# 预测函数
def predict_digit(model, image_path):
# 预处理图像
img_tensor = preprocess_image(image_path)
# 预测
with torch.no_grad(): # 不计算梯度
outputs = model(img_tensor)
_, predicted = torch.max(outputs.data, 1)
return predicted.item() # 返回预测的数字
# 示例使用
if __name__ == '__main__':
# 加载模型
model = load_model('mnist_model.pth')
# 预测示例图像
test_image_path = 'test_digit.png' # 用户需要提供的测试图像路径
try:
predicted_digit = predict_digit(model, test_image_path)
print(f"预测的数字是: {predicted_digit}")
except Exception as e:
print(f"预测出错: {str(e)}")
使用gpu0(第一块gpu)进行训练/推理:
torch.cuda.set_device(0)
model = model.cuda(0)
使用cpu记性训练/推理:
model = model.cpu()
怎么用pytorch训练一个模型-手写数字识别
手把手教你如何跑通一个手写中文汉字识别模型-OCR识别【pytorch】
手把手教你用PyTorch从零训练自己的大模型(非常详细)零基础入门到精通,收藏这一篇就够了
揭秘大模型的训练方法:使用PyTorch进行超大规模深度学习模型训练
全套解决方案:基于pytorch、transformers的中文NLP训练框架,支持大模型训练和文本生成,快速上手,海量训练数据!
用 pytorch 从零开始创建大语言模型(三):编码注意力机制