【神经网络】10 - 网络模型的保存与读取

发布于:2024-05-09 ⋅ 阅读:(27) ⋅ 点赞:(0)

10 - 网络模型的保存与读取

概念

在训练模型时,训练完成后我们需要保存模型,使用模型预测时我们需要读取模型。

为什么要保存和读取模型?

  1. 节省训练时间:训练深度学习模型通常需要大量的计算资源和时间,特别是当我们处理大型数据集并训练复杂的网络结构时。保存训练好的模型权重就意味着我们可以随时记住当前的学习成果,而无需从头开始训练。
  2. 持续优化:在一些复杂的任务中,我们可能需要多次调整模型的参数或者结构,然后再重新训练。保存模型让我们可以在之前的训练结果的基础上继续优化,而不是每次都重新开始。
  3. 模型部署:当模型训练完毕并且表现良好,我们可能需要将模型部署到不同的环境中,例如服务器或者移动设备上。在这些场景下,我们需要保存模型,并加载到目标环境中。
  4. 模型复用:预训练的模型(例如在ImageNet数据集上训练的模型)可以被用作新任务的起点,这被称作迁移学习。通过加载预训练模型的权重,我们可以利用已经学到的特征,更快速并且更有效地完成新任务的训练。

示例

方式1

import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights="DEFAULT")
# 保存方式1
# 不仅保存了模型,也保存了参数
torch.save(vgg16, "saved_model/vgg16_method1.pth")
import torch

# 读取方式1(对应保存方式1)
model = torch.load("saved_model/vgg16_method1.pth")
print(model)

img

在使用方法1进行保存时,如果你自己写了一个模型类,使用方式1保存模型类的对象,在加载保存后的文件时需要原始模型的类的定义。

方式2

import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights="DEFAULT")

# 保存方式2
# 只保存模型参数(官方推荐)
torch.save(vgg16.state_dict(), "saved_model/vgg16_method2.pth")
import torch

# 读取方式2(对应保存方式2)
model2 = torch.load("saved_model/vgg16_method2.pth")
print(model2)

# 读取方式2(对应保存方式2)
# 加载参数
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("saved_model/vgg16_method2.pth"))
print(vgg16)

img


网站公告

今日签到

点亮在社区的每一天
去签到