第一步:准备数据
使用开源的 MNIST 手写数字数据集(训练脚本中通过 torchvision.datasets.MNIST 自动下载到 ./data 目录)。
第二步:搭建模型
我们这里搭建了一个LeNet5网络
参考代码如下(保存为 model.py,训练脚本通过 from model import LeNet5 导入):
import torch
from torch import nn
class Reshape(nn.Module):
    """Reshape the input into a batch of single-channel 28x28 images.

    Any tensor whose element count is a multiple of 28 * 28 is turned into
    shape ``(N, 1, 28, 28)``; the batch size ``N`` is inferred.
    """

    def forward(self, x):
        # reshape (unlike view) also works on non-contiguous tensors, e.g. a
        # transposed batch, copying only when a zero-copy view is impossible.
        return x.reshape(-1, 1, 28, 28)
class LeNet5(nn.Module):
    """LeNet-5 style CNN for single-channel 28x28 digit images, 10 classes.

    ``forward`` returns class PROBABILITIES (a softmax over the 10 outputs),
    not raw logits. ``nn.CrossEntropyLoss`` applies log-softmax to its input
    itself, so do not pass this model's output to it directly: either use
    ``self.net(x)`` to get logits, or take ``torch.log`` of the probabilities
    and use ``nn.NLLLoss``.
    """

    def __init__(self):
        super(LeNet5, self).__init__()
        self.net = nn.Sequential(
            # Any input with a multiple of 784 elements -> (N, 1, 28, 28).
            Reshape(),
            # CONV1, ReLU1, POOL1: padding=2 keeps 28x28, pooling -> 6 x 14 x 14
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # CONV2, ReLU2, POOL2: 14x14 -> 10x10, pooling -> 16 x 5 x 5
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            # FC1
            nn.Linear(in_features=16 * 5 * 5, out_features=120),
            nn.ReLU(),
            # FC2
            nn.Linear(in_features=120, out_features=84),
            nn.ReLU(),
            # FC3: one output per digit class
            nn.Linear(in_features=84, out_features=10),
        )
        # Softmax over the class dimension of the (N, 10) logits. dim is given
        # explicitly: the implicit-dim form is deprecated and emits a warning.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities of shape (N, 10); each row sums to 1."""
        logits = self.net(x)
        prob = self.softmax(logits)
        return prob
if __name__ == '__main__':
    model = LeNet5()

    # Push a random dummy batch through the layers one by one and report the
    # shape each layer produces.
    feature_map = torch.rand(size=(256, 1, 28, 28), dtype=torch.float32)
    for module in model.net:
        feature_map = module(feature_map)
        print(type(module).__name__, '\toutput shape: \t', feature_map.shape)

    # Full forward pass (including the softmax) on a single random image.
    sample = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
    print(model(sample))
第三步:训练代码
import torch
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from model import LeNet5
# DATASET
# Fit on the MNIST training split (train=True). The original code loaded
# train=False here, which made the "training" and "test" sets the same images.
train_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=ToTensor(),
)
# Held-out MNIST test split, used only for evaluation.
test_data = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=ToTensor(),
)
# PREPROCESS
batch_size = 256
# Reshuffle the training samples every epoch so mini-batches are not always
# drawn in the same fixed order; evaluation order does not matter.
train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)

# Sanity check: inspect the shapes of a single training batch.
for X, y in train_dataloader:
    print(X.shape)  # torch.Size([256, 1, 28, 28])
    print(y.shape)  # torch.Size([256])
    break
# MODEL
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = LeNet5()
model = model.to(device)
# TRAIN MODEL
# LeNet5.forward() ends with a softmax, so the model returns probabilities, not
# logits. nn.CrossEntropyLoss applies log-softmax to its input itself; feeding
# it those probabilities applies softmax twice, and with 10 classes the loss
# can then never drop below ln(1 + 9/e) ~= 1.4612 (the plateau visible in the
# training log). Taking the log of the probabilities and using the negative
# log-likelihood yields the true cross-entropy instead.
_nll_loss = nn.NLLLoss()


def loss_func(prob, target):
    """Cross-entropy loss for a model whose outputs are already probabilities.

    Args:
        prob: tensor of class probabilities, one row per sample (rows are
            expected to sum to 1, e.g. a softmax output).
        target: tensor of integer class indices, one per row of ``prob``.

    Returns:
        Scalar tensor: the batch mean of ``-log(prob[i, target[i]])``.
    """
    # Clamp before the log so a probability that underflowed to 0 gives a
    # large finite loss instead of inf/nan.
    return _nll_loss(torch.log(prob.clamp_min(1e-12)), target)
# Adam over all of the model's parameters with the optimizer's default settings.
optimizer = torch.optim.Adam(model.parameters())
def train(dataloader, model, loss_func, optimizer, epoch):
    """Train ``model`` for one pass over ``dataloader`` and print the epoch loss.

    Args:
        dataloader: iterable of ``(X, y)`` batches.
        model: module to optimise; each batch is moved to the device that the
            model's parameters live on.
        loss_func: callable ``loss_func(model_output, y)`` returning a scalar
            tensor; assumed to average over the batch (mean reduction).
        optimizer: optimizer over ``model``'s parameters.
        epoch: zero-based epoch index, printed as ``epoch + 1``.

    Returns:
        float: sample-weighted mean training loss over the epoch, or ``nan``
        when ``dataloader`` yields no samples.
    """
    model.train()
    # Follow the model rather than a module-level global, so the function works
    # wherever the model was placed.
    model_device = next(model.parameters()).device
    loss_sum, seen = 0.0, 0
    for X, y in dataloader:
        X, y = X.to(model_device), y.to(model_device)
        y_hat = model(X)
        loss = loss_func(y_hat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight each batch by its size so a short final batch does not
        # distort the epoch average.
        loss_sum += loss.item() * len(X)
        seen += len(X)
    epoch_loss = loss_sum / seen if seen else float('nan')
    print(f'EPOCH{epoch+1}\tloss: {epoch_loss:>7f}', end='\t')
    return epoch_loss
# Test model
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader``; print accuracy and mean loss.

    Args:
        dataloader: iterable of ``(X, y)`` batches.
        model: module whose output rows are scored per class; the predicted
            class is the argmax along dim 1.
        loss_fn: callable ``loss_fn(model_output, y)`` returning a scalar
            tensor; assumed to average over the batch (mean reduction).

    Returns:
        tuple: ``(accuracy, mean_loss)``, with accuracy as a fraction in
        [0, 1]; both are ``nan`` when the loader yields no samples.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0.0, 0
    with torch.no_grad():
        for X, y in dataloader:
            # `device` is the module-level device chosen for the model.
            X, y = X.to(device), y.to(device)
            pred = model(X)
            # Weight by batch size so the short final batch does not skew the
            # average the way a plain per-batch mean would.
            loss_sum += loss_fn(pred, y).item() * len(X)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            seen += len(X)
    if seen:
        accuracy, test_loss = correct / seen, loss_sum / seen
    else:
        # Avoid ZeroDivisionError on an empty loader.
        accuracy = test_loss = float('nan')
    print(f'Test Error: Accuracy: {(100 * accuracy):>0.1f}%, Average loss: {test_loss:>8f}\n')
    return accuracy, test_loss
if __name__ == '__main__':
    num_epochs = 80
    for epoch_idx in range(num_epochs):
        # One optimisation pass over the training data, then a full evaluation.
        train(train_dataloader, model, loss_func, optimizer, epoch_idx)
        test(test_dataloader, model, loss_func)

    # Persist only the learned weights (the state_dict), not the module object.
    checkpoint_path = 'model.pth'
    torch.save(model.state_dict(), checkpoint_path)
    print('Saved PyTorch LeNet5 State to model.pth')
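第四步:统计训练过程
注:下面日志中的训练损失最终停在约 1.4612 而不是趋近于 0。这是因为生成这份日志的代码把模型已经过 softmax 的概率输出直接送入了 nn.CrossEntropyLoss,而它内部还会再做一次 log-softmax。在 10 个类别下,这样计算出的损失下界为 ln(1 + 9/e) ≈ 1.4612。另外,该代码中的训练集和测试集都取自 MNIST 测试集(train=False),因此日志里的“测试”准确率并不是在未见过的数据上测得的。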
EPOCH1 loss: 1.908403 Test Error: Accuracy: 58.3%, Average loss: 1.943602
EPOCH2 loss: 1.776060 Test Error: Accuracy: 72.2%, Average loss: 1.750917
EPOCH3 loss: 1.717706 Test Error: Accuracy: 73.6%, Average loss: 1.730332
EPOCH4 loss: 1.719344 Test Error: Accuracy: 76.0%, Average loss: 1.703456
EPOCH5 loss: 1.659312 Test Error: Accuracy: 76.6%, Average loss: 1.694500
EPOCH6 loss: 1.647946 Test Error: Accuracy: 76.9%, Average loss: 1.691286
EPOCH7 loss: 1.653712 Test Error: Accuracy: 77.0%, Average loss: 1.690819
EPOCH8 loss: 1.653270 Test Error: Accuracy: 76.8%, Average loss: 1.692459
EPOCH9 loss: 1.649021 Test Error: Accuracy: 77.5%, Average loss: 1.686158
EPOCH10 loss: 1.648204 Test Error: Accuracy: 78.3%, Average loss: 1.678802
EPOCH11 loss: 1.647159 Test Error: Accuracy: 78.4%, Average loss: 1.676133
EPOCH12 loss: 1.647390 Test Error: Accuracy: 78.6%, Average loss: 1.674455
EPOCH13 loss: 1.646807 Test Error: Accuracy: 78.4%, Average loss: 1.675752
EPOCH14 loss: 1.630824 Test Error: Accuracy: 79.1%, Average loss: 1.668470
EPOCH15 loss: 1.524222 Test Error: Accuracy: 86.3%, Average loss: 1.599240
EPOCH16 loss: 1.524022 Test Error: Accuracy: 86.7%, Average loss: 1.594947
EPOCH17 loss: 1.524296 Test Error: Accuracy: 87.1%, Average loss: 1.588946
EPOCH18 loss: 1.523599 Test Error: Accuracy: 87.3%, Average loss: 1.588275
EPOCH19 loss: 1.523655 Test Error: Accuracy: 87.5%, Average loss: 1.586576
EPOCH20 loss: 1.523659 Test Error: Accuracy: 88.2%, Average loss: 1.579286
EPOCH21 loss: 1.523733 Test Error: Accuracy: 87.9%, Average loss: 1.582472
EPOCH22 loss: 1.523748 Test Error: Accuracy: 88.2%, Average loss: 1.578699
EPOCH23 loss: 1.523788 Test Error: Accuracy: 88.0%, Average loss: 1.579700
EPOCH24 loss: 1.523708 Test Error: Accuracy: 88.1%, Average loss: 1.579758
EPOCH25 loss: 1.523683 Test Error: Accuracy: 88.4%, Average loss: 1.575913
EPOCH26 loss: 1.523646 Test Error: Accuracy: 88.7%, Average loss: 1.572831
EPOCH27 loss: 1.523654 Test Error: Accuracy: 88.9%, Average loss: 1.570528
EPOCH28 loss: 1.523642 Test Error: Accuracy: 89.0%, Average loss: 1.570223
EPOCH29 loss: 1.523663 Test Error: Accuracy: 89.0%, Average loss: 1.570385
EPOCH30 loss: 1.523658 Test Error: Accuracy: 88.9%, Average loss: 1.571195
EPOCH31 loss: 1.523653 Test Error: Accuracy: 88.4%, Average loss: 1.575981
EPOCH32 loss: 1.523653 Test Error: Accuracy: 89.0%, Average loss: 1.570087
EPOCH33 loss: 1.523642 Test Error: Accuracy: 88.9%, Average loss: 1.571018
EPOCH34 loss: 1.523649 Test Error: Accuracy: 89.0%, Average loss: 1.570439
EPOCH35 loss: 1.523629 Test Error: Accuracy: 90.4%, Average loss: 1.555473
EPOCH36 loss: 1.461187 Test Error: Accuracy: 97.1%, Average loss: 1.491042
EPOCH37 loss: 1.461230 Test Error: Accuracy: 97.7%, Average loss: 1.485049
EPOCH38 loss: 1.461184 Test Error: Accuracy: 97.7%, Average loss: 1.485653
EPOCH39 loss: 1.461156 Test Error: Accuracy: 98.2%, Average loss: 1.479966
EPOCH40 loss: 1.461335 Test Error: Accuracy: 98.2%, Average loss: 1.479197
EPOCH41 loss: 1.461152 Test Error: Accuracy: 98.7%, Average loss: 1.475477
EPOCH42 loss: 1.461153 Test Error: Accuracy: 98.7%, Average loss: 1.475124
EPOCH43 loss: 1.461153 Test Error: Accuracy: 98.9%, Average loss: 1.472885
EPOCH44 loss: 1.461151 Test Error: Accuracy: 99.1%, Average loss: 1.470957
EPOCH45 loss: 1.461156 Test Error: Accuracy: 99.1%, Average loss: 1.471141
EPOCH46 loss: 1.461152 Test Error: Accuracy: 99.1%, Average loss: 1.470793
EPOCH47 loss: 1.461151 Test Error: Accuracy: 98.8%, Average loss: 1.474548
EPOCH48 loss: 1.461151 Test Error: Accuracy: 99.1%, Average loss: 1.470666
EPOCH49 loss: 1.461151 Test Error: Accuracy: 99.1%, Average loss: 1.471546
EPOCH50 loss: 1.461151 Test Error: Accuracy: 99.0%, Average loss: 1.471407
EPOCH51 loss: 1.461151 Test Error: Accuracy: 98.8%, Average loss: 1.473795
EPOCH52 loss: 1.461164 Test Error: Accuracy: 98.2%, Average loss: 1.480009
EPOCH53 loss: 1.461151 Test Error: Accuracy: 99.2%, Average loss: 1.469931
EPOCH54 loss: 1.461152 Test Error: Accuracy: 99.2%, Average loss: 1.469916
EPOCH55 loss: 1.461151 Test Error: Accuracy: 98.9%, Average loss: 1.472574
EPOCH56 loss: 1.461151 Test Error: Accuracy: 98.6%, Average loss: 1.476035
EPOCH57 loss: 1.461151 Test Error: Accuracy: 98.2%, Average loss: 1.478933
EPOCH58 loss: 1.461150 Test Error: Accuracy: 99.4%, Average loss: 1.468186
EPOCH59 loss: 1.461151 Test Error: Accuracy: 99.4%, Average loss: 1.467602
EPOCH60 loss: 1.461151 Test Error: Accuracy: 99.1%, Average loss: 1.471206
EPOCH61 loss: 1.461151 Test Error: Accuracy: 98.8%, Average loss: 1.473356
EPOCH62 loss: 1.461151 Test Error: Accuracy: 99.2%, Average loss: 1.470242
EPOCH63 loss: 1.461150 Test Error: Accuracy: 99.1%, Average loss: 1.470826
EPOCH64 loss: 1.461151 Test Error: Accuracy: 98.7%, Average loss: 1.474476
EPOCH65 loss: 1.461150 Test Error: Accuracy: 99.3%, Average loss: 1.469116
EPOCH66 loss: 1.461150 Test Error: Accuracy: 99.4%, Average loss: 1.467823
EPOCH67 loss: 1.461150 Test Error: Accuracy: 99.5%, Average loss: 1.466486
EPOCH68 loss: 1.461152 Test Error: Accuracy: 99.3%, Average loss: 1.468688
EPOCH69 loss: 1.461150 Test Error: Accuracy: 99.5%, Average loss: 1.466256
EPOCH70 loss: 1.461150 Test Error: Accuracy: 99.5%, Average loss: 1.466588
EPOCH71 loss: 1.461150 Test Error: Accuracy: 99.6%, Average loss: 1.465280
EPOCH72 loss: 1.461150 Test Error: Accuracy: 99.4%, Average loss: 1.467110
EPOCH73 loss: 1.461151 Test Error: Accuracy: 99.6%, Average loss: 1.465245
EPOCH74 loss: 1.461150 Test Error: Accuracy: 99.5%, Average loss: 1.466551
EPOCH75 loss: 1.461150 Test Error: Accuracy: 99.5%, Average loss: 1.466001
EPOCH76 loss: 1.461150 Test Error: Accuracy: 99.3%, Average loss: 1.468074
EPOCH77 loss: 1.461151 Test Error: Accuracy: 99.6%, Average loss: 1.465709
EPOCH78 loss: 1.461150 Test Error: Accuracy: 99.5%, Average loss: 1.466567
EPOCH79 loss: 1.461150 Test Error: Accuracy: 99.6%, Average loss: 1.464922
EPOCH80 loss: 1.461150 Test Error: Accuracy: 99.6%, Average loss: 1.465109
第五步:搭建GUI界面
第六步:整个工程的内容
工程中包含训练代码、训练好的模型和训练过程记录,并提供数据集和 GUI 界面代码;具体使用方法请参阅其中的“文档说明_必看.docx”。
代码的下载路径(新窗口打开链接):基于Pytorch深度学习神经网络MNIST手写数字识别系统源码(带界面和手写画板)
有问题可以私信或者留言,有问必答