CNN卷积神经网络预测手写数字的PyTorch实现

发布于:2025-08-14 ⋅ 阅读:(14) ⋅ 点赞:(0)

一、导入第三方库

import torch
import os
from PIL import Image
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

二、手写数字数据集准备

图像标签由文件名给出：文件名按下划线 `_` 分隔后的第三段（去掉扩展名）即为数字标签，例如 `a_b_7.jpg` 的标签为 7。

#数据集类
class MNISTDataset(Dataset):
    """Image-folder dataset whose digit labels are encoded in the file names.

    Each file name is split on ``"_"``; the third field, cut at its first
    ``"."``, is parsed as the integer label (e.g. ``"a_b_7.jpg"`` -> 7).

    Args:
        files: Image file names, relative to ``root_dir``.
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to each grayscale PIL image
            before it is returned.

    Raises:
        ValueError: If a file name does not carry a parsable label.
    """

    def __init__(self, files, root_dir, transform=None):
        self.files = files
        self.root_dir = root_dir
        self.transform = transform
        # Labels are parsed once up front so a badly named file fails at
        # construction time rather than in the middle of training.
        self.labels = [self._parse_label(name) for name in files]

    @staticmethod
    def _parse_label(name):
        """Return the integer label encoded in ``name`` (third ``_`` field)."""
        parts = name.split("_")
        try:
            return int(parts[2].split(".")[0])
        except (IndexError, ValueError) as err:
            raise ValueError(
                f"Cannot parse a digit label from file name {name!r}; "
                "expected the form '<a>_<b>_<label>.<ext>'"
            ) from err

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file.

        The image is converted to single-channel grayscale ('L') and then
        passed through ``transform`` if one was given.
        """
        img_path = os.path.join(self.root_dir, self.files[idx])
        # Close the underlying file promptly; convert() returns a new,
        # fully loaded in-memory image that remains valid after the block.
        with Image.open(img_path) as src:
            img = src.convert('L')

        if self.transform:
            img = self.transform(img)

        label = self.labels[idx]
        return img, label

三、CNN模型的PyTorch实现

class CNN(nn.Module):
    """Small two-stage convolutional classifier for 1x28x28 digit images.

    Each stage is a 5x5 convolution, ReLU and 2x2 max-pool. For a 28x28
    input the spatial size goes 28 -> 24 -> 12 -> 8 -> 4, so the flattened
    feature vector has 20 * 4 * 4 = 320 entries. The head maps these to 10
    raw class scores (logits, no softmax applied).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 input channel -> 10 feature maps.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Stage 2: 10 -> 20 feature maps.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Classifier head: 320 flattened features -> 50 hidden units -> 10 logits.
        hidden = 50
        self.fc = nn.Sequential(
            nn.Linear(in_features=320, out_features=hidden),
            nn.ReLU(),
            nn.Linear(in_features=hidden, out_features=10),
        )

    def forward(self, x):
        n = x.size(0)
        feats = self.conv2(self.conv1(x))
        # Flatten every dimension except the batch dimension.
        return self.fc(feats.view(n, -1))

四、主程序

if __name__ == '__main__':
    # Root folder holding the image folders and receiving the output plots.
    # The default is the author's machine; set CNN_BASE_DIR to run elsewhere.
    base_dir = os.environ.get('CNN_BASE_DIR', 'C:\\Users\\Administrator\\PycharmProjects\\CNN')
    train_dir = os.path.join(base_dir, "minist_train")
    test_dir = os.path.join(base_dir, "minist_test")

    # Collect the .jpg file names. os.listdir order is arbitrary, so sort the
    # lists to make the dataset order (and thus "the first test image")
    # reproducible across runs and machines.
    train_files = sorted(f for f in os.listdir(train_dir) if f.endswith('.jpg'))
    test_files = sorted(f for f in os.listdir(test_dir) if f.endswith('.jpg'))

    # Fail early with a clear message. Otherwise an empty folder surfaces
    # later as a ZeroDivisionError (average loss / accuracy) or IndexError
    # (test_dataset[0]).
    if not train_files:
        raise FileNotFoundError(f"No .jpg training images found in {train_dir}")
    if not test_files:
        raise FileNotFoundError(f"No .jpg test images found in {test_dir}")

    # Preprocessing pipeline:
    # - resize to 28x28, the size CNN's Linear(320, ...) layer is built for;
    # - ToTensor scales pixels to [0, 1];
    # - Normalize(0.5, 0.5) then maps them to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    # Datasets and loaders (only the training data is shuffled).
    train_dataset = MNISTDataset(train_files, train_dir, transform=transform)
    test_dataset = MNISTDataset(test_files, test_dir, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    model = CNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    def train_cnn(epoch):
        """Train ``model`` for ``epoch`` epochs and save a loss-curve plot.

        Prints the batch loss every 100 batches and the mean loss of each
        epoch, then writes ``training_loss_curve.jpg`` into ``base_dir``.
        """
        model.train()
        train_loss = []

        for epoch_idx in range(epoch):
            running_loss = 0.0
            for batch_idx, (data, target) in enumerate(train_loader):
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                if batch_idx % 100 == 0:
                    print(f'Epoch: {epoch_idx + 1}, Batch: {batch_idx}, Loss: {loss.item():.6f}')

            avg_loss = running_loss / len(train_loader)
            train_loss.append(avg_loss)
            print(f'Epoch {epoch_idx + 1}/{epoch}, Average Loss: {avg_loss:.6f}')

        # Loss curve. Epochs are reported 1-based above, so put them on the
        # x-axis 1-based too (plotting the bare list would start at 0).
        plt.figure(figsize=(12, 6))
        plt.plot(range(1, len(train_loss) + 1), train_loss, marker='o')
        plt.title("训练过程中损失函数值变化")
        plt.xlabel("Epoch")
        plt.ylabel("损失函数值")
        plt.grid()

        loss_plot_path = os.path.join(base_dir, "training_loss_curve.jpg")
        plt.savefig(loss_plot_path, dpi=300, bbox_inches='tight')
        plt.close()

    def test_cnn():
        """Evaluate ``model`` on ``test_loader``; print and return the accuracy in percent."""
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in test_loader:
                predicted = model(data).argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        accuracy = 100 * correct / total
        print(f'测试集准确率: {accuracy:.2f}%')
        return accuracy

    # Train, then evaluate.
    epoch = 10
    train_cnn(epoch)
    test_accuracy = test_cnn()

    # Show the prediction for the first test image (first in sorted file order).
    model.eval()
    with torch.no_grad():
        test_img, test_label = test_dataset[0]
        pred = model(test_img.unsqueeze(0)).argmax(dim=1)  # unsqueeze adds the batch dimension

    plt.imshow(test_img.squeeze(), cmap='gray')
    plt.title(f"真实数字: {test_label}, 预测数字: {pred.item()}")
    plt.axis('off')

    pred_plot_path = os.path.join(base_dir, "first_test_pred.jpg")
    plt.savefig(pred_plot_path, dpi=300, bbox_inches='tight')
    plt.close()

五、运行结果

5.1 损失函数曲线图

5.2 测试集第一张图像的预测结果


网站公告

今日签到

点亮在社区的每一天
去签到