👋 你好!这里有实用干货与深度分享✨✨ 若有帮助,欢迎:
👍 点赞 | ⭐ 收藏 | 💬 评论 | ➕ 关注 ,解锁更多精彩!
📁 收藏专栏即可第一时间获取最新推送🔔。
📖后续我将持续带来更多优质内容,期待与你一同探索知识,携手前行,共同进步🚀。
模型导出
本文介绍如何将深度学习模型导出为不同的部署格式,包括ONNX、TorchScript等,并对各种格式的优缺点和最佳实践进行总结,帮助你高效完成模型部署准备。
1. 导出格式对比
格式 | 优点 | 缺点 | 适用场景 |
---|---|---|---|
ONNX | - 跨平台跨框架 - 生态丰富 - 标准统一 - 广泛支持 |
- 可能存在算子兼容问题 - 部分高级特性支持有限 |
- 跨平台部署 - 使用标准推理引擎 - 需要广泛兼容性 |
TorchScript | - 与PyTorch无缝集成 - 支持动态图结构 - 调试方便 - 性能优化 |
- 仅限PyTorch生态 - 文件体积较大 |
- PyTorch生产环境 - 需要动态特性 - 性能要求高 |
TensorRT | - 极致优化性能 - 支持GPU加速 - 低延迟推理 |
- 仅支持NVIDIA GPU - 配置复杂 |
- 高性能推理场景 - 实时应用 - 边缘计算 |
TensorFlow SavedModel | - TensorFlow生态完整支持 - 部署便捷 |
- 跨框架兼容性差 | - TensorFlow生产环境 |
2. ONNX格式导出
2.1 基本导出
ONNX格式适用于跨平台部署,支持多种推理引擎(如ONNXRuntime、TensorRT、OpenVINO等)。
import torch
import torch.onnx
def export_to_onnx(model, input_shape, save_path):
# 设置模型为评估模式
model.eval()
# 创建示例输入
dummy_input = torch.randn(input_shape)
# 导出模型
torch.onnx.export(
model, # 要导出的模型
dummy_input, # 模型输入
save_path, # 保存路径
export_params=True, # 导出模型参数
opset_version=11, # ONNX算子集版本
do_constant_folding=True, # 常量折叠优化
input_names=['input'], # 输入名称
output_names=['output'], # 输出名称
dynamic_axes={
'input': {0: 'batch_size'}, # 动态批次大小
'output': {0: 'batch_size'}
}
)
print(f"Model exported to {save_path}")
# 使用示例
model = YourModel()
model.load_state_dict(torch.load('model.pth'))
export_to_onnx(model, (1, 3, 224, 224), 'model.onnx')
2.2 验证导出模型
导出后必须进行全面验证,包括结构检查和数值对比:
- 结构验证
import onnx
import onnxruntime
import numpy as np
def verify_onnx_structure(onnx_path):
# 加载并检查模型结构
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
# 打印模型信息
print("模型输入:")
for input in onnx_model.graph.input:
print(f"- {input.name}: {input.type.tensor_type.shape}")
print("\n模型输出:")
for output in onnx_model.graph.output:
print(f"- {output.name}: {output.type.tensor_type.shape}")
- 数值精度对比
def compare_outputs(model, onnx_path, input_data):
# PyTorch结果
model.eval()
with torch.no_grad():
torch_output = model(torch.from_numpy(input_data))
# ONNX结果
ort_output = verify_onnx_model(onnx_path, input_data)
# 比较差异
diff = np.abs(torch_output.numpy() - ort_output).max()
print(f"最大误差: {diff}")
return diff < 1e-5
- 验证 ONNX 模型
import onnx
import onnxruntime
import numpy as np
def verify_onnx_model(onnx_path, input_data):
# 加载ONNX模型
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
# 创建推理会话
ort_session = onnxruntime.InferenceSession(onnx_path)
# 准备输入数据
ort_inputs = {ort_session.get_inputs()[0].name: input_data}
# 运行推理
ort_outputs = ort_session.run(None, ort_inputs)
return ort_outputs[0]
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = verify_onnx_model('model.onnx', input_data)
2.3 ONNX模型优化
使用ONNX Runtime提供的优化工具进一步提升性能:
import onnxruntime as ort
from onnxruntime.transformers import optimizer
def optimize_onnx_model(onnx_path, optimized_path):
# 创建优化器配置
opt_options = optimizer.OptimizationConfig(
optimization_level=99, # 最高优化级别
enable_gelu_approximation=True,
enable_layer_norm_optimization=True,
enable_attention_fusion=True
)
# 优化模型
optimized_model = optimizer.optimize_model(
onnx_path,
'cpu', # 或 'gpu'
opt_options
)
# 保存优化后的模型
optimized_model.save_model_to_file(optimized_path)
print(f"优化后的模型已保存至 {optimized_path}")
optimizer.optimize_model()
第二个参数是优化目标设备,支持 ‘cpu’ 或 ‘gpu’。- 优化目标设备:指定模型优化时的目标硬件平台。例如:
- ‘cpu’:针对 CPU 进行优化(如调整算子、量化参数等)。
- ‘gpu’:针对 GPU 进行优化(如使用 CUDA 内核、张量核心等)。
*运行时设备:优化后的模型可以在其他设备上运行,但性能可能受影响。例如: - 针对 CPU 优化的模型可以在 GPU 上运行,但可能无法充分利用 GPU 特性。
- 针对 GPU 优化的模型在 CPU 上运行可能会报错或性能下降。
建议保持优化目标与运行设备一致以获得最佳性能。
- 优化目标设备:指定模型优化时的目标硬件平台。例如:
3. TorchScript格式导出
3.1 trace导出
适用于前向计算图结构固定的模型。
import torch
def export_torchscript_trace(model, input_shape, save_path):
model.eval()
example_input = torch.randn(input_shape)
# 使用跟踪法导出
traced_model = torch.jit.trace(model, example_input)
traced_model.save(save_path)
print(f"Traced model exported to {save_path}")
return traced_model
3.2 script导出
适用于包含条件分支、循环等动态结构的模型。
import torch
import torch.nn as nn
@torch.jit.script
class ScriptableModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 64, 3)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.relu(x)
return x
def export_torchscript_script(model, save_path):
scripted_model = torch.jit.script(model)
scripted_model.save(save_path)
print(f"Scripted model exported to {save_path}")
return scripted_model
3.3 TorchScript模型验证
验证TorchScript模型的正确性:
def verify_torchscript_model(original_model, ts_model_path, input_data):
# 原始模型输出
original_model.eval()
with torch.no_grad():
original_output = original_model(input_data)
# 加载TorchScript模型
ts_model = torch.jit.load(ts_model_path)
ts_model.eval()
# TorchScript模型输出
with torch.no_grad():
ts_output = ts_model(input_data)
# 比较差异
diff = torch.abs(original_output - ts_output).max().item()
print(f"最大误差: {diff}")
return diff < 1e-5
4. 自定义算子处理
4.1 ONNX自定义算子
如需导出自定义算子,可通过ONNX扩展机制实现。
from onnx import helper
def create_custom_op():
# 定义自定义算子
custom_op = helper.make_node(
'CustomOp', # 算子名称
inputs=['input'], # 输入
outputs=['output'], # 输出
domain='custom.domain'
)
return custom_op
def register_custom_op():
# 注册自定义算子
from onnxruntime.capi import _pybind_state as C
C.register_custom_op('CustomOp', 'custom.domain')
4.2 TorchScript自定义算子
可通过C++扩展自定义TorchScript算子。
from torch.utils.cpp_extension import load
# 编译自定义C++算子
custom_op = load(
name="custom_op",
sources=["custom_op.cpp"],
verbose=True
)
# 在模型中使用自定义算子
class ModelWithCustomOp(nn.Module):
def forward(self, x):
return custom_op.forward(x)
4.3 自定义算子示例
下面是一个完整的自定义算子实现示例:
// custom_op.cpp
#include <torch/extension.h>
torch::Tensor custom_forward(torch::Tensor input) {
return input.sigmoid().mul(2.0);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &custom_forward, "Custom forward function");
}
# 在Python中使用
import torch
from torch.utils.cpp_extension import load
# 编译自定义算子
custom_op = load(
name="custom_op",
sources=["custom_op.cpp"],
verbose=True
)
# 测试自定义算子
input_tensor = torch.randn(2, 3)
output = custom_op.forward(input_tensor)
print(output)
5. 模型部署示例
5.1 ONNXRuntime部署
import onnxruntime as ort
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
def preprocess_image(image_path, input_shape):
# 图像预处理
transform = transforms.Compose([
transforms.Resize((input_shape[2], input_shape[3])),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
image = Image.open(image_path).convert('RGB')
image_tensor = transform(image).unsqueeze(0).numpy()
return image_tensor
def onnx_inference(onnx_path, image_path, input_shape=(1, 3, 224, 224)):
# 加载ONNX模型
session = ort.InferenceSession(onnx_path)
# 预处理图像
input_data = preprocess_image(image_path, input_shape)
# 获取输入输出名称
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# 执行推理
result = session.run([output_name], {input_name: input_data})
return result[0]
5.2 TorchScript部署
import torch
from PIL import Image
import torchvision.transforms as transforms
def torchscript_inference(model_path, image_path):
# 加载TorchScript模型
model = torch.jit.load(model_path)
model.eval()
# 图像预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载并处理图像
image = Image.open(image_path).convert('RGB')
input_tensor = transform(image).unsqueeze(0)
# 执行推理
with torch.no_grad():
output = model(input_tensor)
return output
6. 常见问题与解决方案
6.1 ONNX导出失败
问题: 导出ONNX时出现算子不支持错误
解决方案:
# 尝试使用更高版本的opset
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=13)
# 或替换不支持的操作
class ModelWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, x):
# 替换不支持的操作为等效操作
return self.model(x)
6.2 TorchScript跟踪失败
问题: 动态控制流导致trace失败
解决方案:
# 使用script而非trace
scripted_model = torch.jit.script(model)
# 或修改模型结构避免动态控制流
class TraceFriendlyModel(nn.Module):
def __init__(self, original_model):
super().__init__()
self.model = original_model
def forward(self, x):
# 移除动态控制流
return self.model.forward_fixed(x)
6.3 推理性能问题
问题: 导出模型推理速度慢
解决方案:
# 1. 使用量化
from torch.quantization import quantize_dynamic
quantized_model = quantize_dynamic(model, {nn.Linear, nn.Conv2d}, dtype=torch.qint8)
# 2. 使用TensorRT优化ONNX
import tensorrt as trt
# TensorRT优化代码...
# 3. 使用ONNX Runtime优化
import onnxruntime as ort
session = ort.InferenceSession("model.onnx", providers=['CUDAExecutionProvider'])
7. 最佳实践
选择合适的导出格式
- ONNX:适合跨平台、跨框架部署,兼容性强
- TorchScript:适合PyTorch生态内部署,支持灵活性高
- 根据目标平台和性能需求选择
优化导出模型
- 使用合适的opset版本(建议11及以上)
- 启用常量折叠等优化选项
- 导出后务必验证模型正确性
- 考虑使用量化和剪枝优化模型大小
处理动态输入
- 设置动态维度(如batch_size)
- 测试不同输入大小,确保模型鲁棒性
- 记录支持的输入范围和约束
文档和版本控制
- 记录导出配置和依赖版本
- 保存模型元数据(如输入输出规格)
- 对模型文件进行版本化管理
- 维护模型卡片(Model Card)记录关键信息
调试技巧
- 使用ONNX Graph Viewer等可视化工具分析模型结构
- 使用Netron查看计算图和参数分布
- 比较原始与导出模型输出,检查数值精度差异
- 遇到兼容性问题时查阅官方文档和社区经验
8. 参考资源
📌 感谢阅读!若文章对你有用,别吝啬互动~
👍 点个赞 | ⭐ 收藏备用 | 💬 留下你的想法 ,关注我,更多干货持续更新!