matlab从pytorch中导入LeNet-5网络框架

发布于:2025-04-05 ⋅ 阅读:(48) ⋅ 点赞:(0)

这里演示从pytorch的LeNet-5网络导入到matlab中进行训练用。

一、Pytorch的LeNet-5网络准备

根据LeNet-5的结构图,我们可以写如下结构

import torch
import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()

        self.feature_extractor = nn.Sequential(
            # C1: Conv(1→6), 输出 28x28 → 6x28x28
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            nn.BatchNorm2d(6),
            nn.ReLU(inplace=True),

            # S2: MaxPool 2x2, 输出 6x14x14
            nn.MaxPool2d(kernel_size=2, stride=2),

            # C3: Conv(6→16), 输出 16x10x10
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),

            # S4: MaxPool 2x2, 输出 16x5x5
            nn.MaxPool2d(kernel_size=2, stride=2),

            # C5: Conv(16→120), 输出 120x1x1(接近 flatten)
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),
            nn.BatchNorm2d(120),
            nn.ReLU(inplace=True)
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),  # [batch, 120]
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(84, num_classes)
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.classifier(x)
        return x


if __name__ == "__main__":
    model = LeNet5()

    model.eval()
    # 示例输入:MNIST 的图像大小 [1, 1, 28, 28]
    example_input = torch.randn(1, 1, 28, 28)
    # Tracing
    traced_model = torch.jit.trace(model, example_input)

    # 保存
    traced_model.save("traced_lenet5.pt")
    print("✅ traced_lenet5.pt 已成功保存!")

二、保存用于导入matlab的model

在上面的代码中,我们有几行是产生trace model的,即

在这里插入图片描述

torch.jit.trace() 是 PyTorch 的一种 静态图(Static Graph)转换方法,它会:

  • 运行一次前向传播(forward),记录下所有的张量操作;
  • 然后构建一个不可变的计算图(graph),这个图就是所谓的 trace model

保存这个model后,我们就得到了traced_lenet5.pt这个文件。

三、导入matlab

导入matlab可以通过APPS里的Deep Network Designer,如下图

在这里插入图片描述

然后通过From PyTorch这个地方,导入刚才保存的网络结构

在这里插入图片描述

点开From PyTorch后, 我们可以复制刚才保存的traced_lenet5.pt这个文件的绝对路径用于导入,如下图

在这里插入图片描述

然后,import就会有,如下结果

在这里插入图片描述

然后,点击红色方框那部分,进行一下输入尺寸的修改

在这里插入图片描述

导入的这个网络框架,我们还要在末尾段加入softmax层,这个层在原pytorch框架里没写

在这里插入图片描述

这样,我们就完成了LeNet5从Pytorch里导入到matlab了。接着我们可以通过Analyze按钮分析这个网络,如下图

在这里插入图片描述

没有问题后,我们就可以Export这个网络到工作区了,输出的网络自动命名为net_1。

在这里插入图片描述

四、用matlab训练这个导入的网络

训练的代码如下

% 创建一个图像数据存储对象 `imds`,用于从名为 "DigitsData" 的文件夹中加载图像数据
imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...  % 指定在加载数据时包含子文件夹中的图像
    LabelSource="foldernames");  % 使用子文件夹的名称作为图像的标签(自动分类)

% 获取数据集中所有的类别名称(即文件夹名),并将其存储在变量 classNames 中
classNames = categories(imds.Labels);  % 将 imds.Labels


%%
% 使用 splitEachLabel 函数将原始图像数据集 imds 随机划分为训练集、验证集和测试集
[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 0.15, "randomized");

% 设置用于网络训练的选项,这里使用的是随机梯度下降动量法(SGDM)
% 最大训练轮数(epoch):训练过程中将整个训练集完整迭代 4 次
% 指定验证数据集,用于在训练过程中评估模型的泛化能力
% 每训练 30 个 mini-batch 执行一次验证评估
% 在训练过程中显示实时图形界面,包括损失值和准确率的变化曲线
% 指定训练期间关注的评估指标为准确率(accuracy)
% 禁止在命令行窗口输出详细训练信息(安静模式)
options = trainingOptions("sgdm", ...  
    MaxEpochs = 4, ...  
    ValidationData = imdsValidation, ... 
    ValidationFrequency = 30, ...  
    Plots = "training-progress", ...  
    Metrics = "accuracy", ...  
    Verbose = false); 



% 使用 trainnet 函数对神经网络进行训练
net = trainnet(imdsTrain, net_1, "crossentropy", options);

%%
% 使用 testnet 函数对训练好的神经网络进行验证,并评估其准确率
accuracy = testnet(net, imdsTest, "accuracy");

%%
% 对测试集进行批量预测,输出每个图像对应的类别得分(概率)
scores = minibatchpredict(net, imdsTest);

% 将得分(scores)转换为类别标签,使用 classNames 映射到原始类名
YTest = scores2label(scores, classNames);


% 获取测试集图像的总数量
numTestObservations = numel(imdsTest.Files);

% 从测试集中随机选取 9 个样本用于可视化
idx = randi(numTestObservations, 9, 1);

% 创建一个新的图形窗口
figure
tiledlayout("flow")  % 使用自动流式布局排列子图(tiled layout)

% 遍历 9 张图像,显示图像并在标题中标注预测类别
for i = 1:9
    nexttile  % 在下一个网格位置准备绘图
    img = readimage(imdsTest, idx(i));  % 读取第 idx(i) 张图像
    imshow(img)  % 显示图像
    title("Predicted Class: " + string(YTest(idx(i))))  % 设置标题,显示预测类别
end

上面用到的数据集是0-9的数字图片,如下图

在这里插入图片描述

训练的详细信息如下

在这里插入图片描述

预测结果显示

在这里插入图片描述


网站公告

今日签到

点亮在社区的每一天
去签到