一、导入第三方库
# Imports and global matplotlib configuration for the MNIST CNN script.
import torch
import os
from PIL import Image
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei'] # Use the SimHei font so the Chinese plot titles/labels render (SimHei must be installed)
plt.rcParams['axes.unicode_minus'] = False # Draw negative numbers with an ASCII hyphen instead of the Unicode minus sign
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
二、手写数据集准备



#数据集类
class MNISTDataset(Dataset):
    """Image-folder dataset whose digit label is encoded in each file name.

    Each file name is split on ``"_"``; the third field, up to its first
    ``"."``, is parsed as the integer label (e.g. ``"a_b_7.jpg"`` -> 7).
    Images are loaded as single-channel (``"L"``) PIL images and passed
    through ``transform`` when one is given.

    Args:
        files: file names (relative to ``root_dir``) that make up the dataset.
        root_dir: directory containing the files.
        transform: optional callable applied to each loaded PIL image.
    """

    def __init__(self, files, root_dir, transform=None):
        self.files = files
        self.root_dir = root_dir
        self.transform = transform
        # Parse every label up front so a malformed file name fails at
        # construction time rather than in the middle of training.
        self.labels = [self._parse_label(f) for f in files]

    @staticmethod
    def _parse_label(filename):
        """Return the integer label stored in the third ``_``-field of *filename*."""
        return int(filename.split("_")[2].split(".")[0])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*."""
        img_path = os.path.join(self.root_dir, self.files[idx])
        # Image.open is lazy and keeps the file handle open until the image
        # object is garbage-collected; convert() yields an in-memory copy, so
        # the source file can be closed immediately via the context manager.
        with Image.open(img_path) as src:
            img = src.convert('L')
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]
三、CNN模型的pytorch实现
class CNN(nn.Module):
    """Small convolutional digit classifier: two conv stages plus an MLP head.

    Each conv stage is a 5x5 convolution (no padding) -> ReLU -> 2x2
    max-pool. For a (batch, 1, 28, 28) input the spatial size goes
    28 -> 24 -> 12 -> 8 -> 4, so the head receives 20 * 4 * 4 = 320
    features. ``forward`` returns unnormalized scores of shape (batch, 10).
    """

    def __init__(self):
        super().__init__()
        # Construction order (conv1, conv2, fc) is kept so parameter
        # initialization and state_dict keys match the layer layout exactly.
        self.conv1 = self._conv_stage(1, 10)
        self.conv2 = self._conv_stage(10, 20)
        self.fc = nn.Sequential(
            nn.Linear(20 * 4 * 4, 50),
            nn.ReLU(),
            nn.Linear(50, 10),
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one conv(5x5) -> ReLU -> max-pool(2x2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x):
        n = x.size(0)
        features = self.conv2(self.conv1(x))
        # Flatten everything except the batch axis before the MLP head.
        return self.fc(features.view(n, -1))
四、主程序
if __name__ == '__main__':
    # Data location. Defaults to the original project folder; set the
    # CNN_BASE_DIR environment variable to run from somewhere else. The two
    # output plots are saved into this folder as well.
    base_dir = os.environ.get('CNN_BASE_DIR', 'C:\\Users\\Administrator\\PycharmProjects\\CNN')
    train_dir = os.path.join(base_dir, "minist_train")
    test_dir = os.path.join(base_dir, "minist_test")

    # Image file names in each folder. os.listdir() returns entries in an
    # arbitrary, filesystem-dependent order; sorting makes the dataset order
    # (and therefore "the first test image" shown below) reproducible.
    train_files = sorted(f for f in os.listdir(train_dir) if f.endswith('.jpg'))
    test_files = sorted(f for f in os.listdir(test_dir) if f.endswith('.jpg'))
    # Fail fast with a clear message: an empty folder would otherwise surface
    # as a ZeroDivisionError (average loss / accuracy) or an IndexError on
    # test_dataset[0], possibly only after the whole training run.
    for folder, names in ((train_dir, train_files), (test_dir, test_files)):
        if not names:
            raise FileNotFoundError(f"No .jpg images found in {folder}")

    # Preprocessing: force 28x28 (the input size CNN's Linear(320, ...) head
    # is sized for), convert to a [0, 1] tensor, then map to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    # Datasets and loaders (training data is reshuffled every epoch).
    train_dataset = MNISTDataset(train_files, train_dir, transform=transform)
    test_dataset = MNISTDataset(test_files, test_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    model = CNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    def train_cnn(epoch):
        """Train ``model`` for ``epoch`` epochs and save the loss curve.

        Prints the batch loss every 100 batches and the mean loss of each
        epoch, then writes ``training_loss_curve.jpg`` into ``base_dir``.

        Returns:
            list[float]: mean training loss of each epoch.
        """
        model.train()
        train_loss = []
        for epoch_idx in range(epoch):
            running_loss = 0.0
            for batch_idx, (data, target) in enumerate(train_loader):
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                if batch_idx % 100 == 0:
                    print(f'Epoch: {epoch_idx + 1}, Batch: {batch_idx}, Loss: {loss.item():.6f}')
            avg_loss = running_loss / len(train_loader)
            train_loss.append(avg_loss)
            print(f'Epoch {epoch_idx + 1}/{epoch}, Average Loss: {avg_loss:.6f}')

        # Loss curve. Plot against 1-based epoch numbers so the x axis matches
        # the "Epoch N" numbering printed above (plt.plot(y) alone starts at 0).
        plt.figure(figsize=(12, 6))
        plt.plot(range(1, len(train_loss) + 1), train_loss)
        plt.title("训练过程中损失函数值变化")
        plt.xlabel("Epoch")
        plt.ylabel("损失函数值")
        plt.grid()
        loss_plot_path = os.path.join(base_dir, "training_loss_curve.jpg")
        plt.savefig(loss_plot_path, dpi=300, bbox_inches='tight')
        plt.close()
        return train_loss

    def test_cnn():
        """Evaluate ``model`` on ``test_loader``; print and return accuracy in percent."""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:
                predicted = model(data).argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
        accuracy = 100 * correct / total
        print(f'测试集准确率: {accuracy:.2f}%')
        return accuracy

    # Train, then evaluate.
    epoch = 10
    train_cnn(epoch)
    test_accuracy = test_cnn()

    # Predict the first test image and save it with its true/predicted label.
    model.eval()
    with torch.no_grad():
        test_img, test_label = test_dataset[0]
        output = model(test_img.unsqueeze(0))  # add the batch dimension
        pred = output.argmax(dim=1)
    plt.figure()
    plt.imshow(test_img.squeeze(), cmap='gray')
    plt.title(f"真实数字: {test_label}, 预测数字: {pred.item()}")
    plt.axis('off')
    pred_plot_path = os.path.join(base_dir, "first_test_pred.jpg")
    plt.savefig(pred_plot_path, dpi=300, bbox_inches='tight')
    plt.close()
五、运行结果
5.1 损失函数曲线图

5.2 测试集第一张图像的预测结果
