深度学习之pth转换为onnx时修改模型定义‌

发布于:2024-12-07 ⋅ 阅读:(188) ⋅ 点赞:(0)

概述

在将PyTorch模型(.pth文件)转换为ONNX格式时,通常的转换过程是通过torch.onnx.export函数来实现的。这个过程主要是将PyTorch模型的计算图导出为ONNX格式,以便在其他框架或环境中使用。

在转换过程中,你通常不能直接在原有的PyTorch模型前后“添加函数”,因为ONNX导出的是静态计算图,它表示的是模型在某一时刻的结构和参数,而不是动态的执行过程。不过,你可以通过‌修改模型定义‌的方式来实现类似的功能。

在导出模型之前,你可以修改模型的定义,将你想要添加的功能集成到模型本身中。例如,如果你想要在模型的前向传播过程中添加某些预处理或后处理步骤,你可以直接将这些步骤写入模型类的forward方法中。

实现步骤

  1. 定义新模型类
  2. 将原模型添加为新模型的成员
  3. 在新模型的forward中,在原有模型之前或之后添加新的层
  4. 初始化新模型
  5. 加载原有模型参数
  6. 导出onnx

python代码

from model import *
from utils import *
from data import *
import cv2


# 这是你修改后的模型定义,集成了额外功能
class ModifiedModel(nn.Module):
    def __init__(self):
        super(ModifiedModel, self).__init__()
        num_classes = 3
        self.original_model = UNet(3, num_classes)
        # 新增的层或修改后的层
        # self.new_layer = torch.argmax()

    def forward(self, x):
        # 在原始模型前添加预处理(如果需要)
        x = self.original_model(x)
        # 在原始模型后添加后处理或新增层的逻辑
        # x = self.new_layer(x)
        x = torch.argmax(x[0], dim=0).unsqueeze(0) * 255
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
weight_path = 'params/unet_CXR.pth'

pretrained_dict = torch.load(weight_path)
# 初始化修改后的模型,并加载原始模型的参数
modified_model = ModifiedModel()
modified_model.to(device)
# 假设我们只关心原始模型的参数,可以直接将其赋值给修改后的模型中的对应部分
modified_model.original_model.load_state_dict(pretrained_dict)
modified_model.eval()

img_data = torch.randn(1, 3, 256, 256)
img_data = img_data.to(device)
out_data = modified_model(img_data)

out_data = out_data.cpu().detach().numpy()
out_data = np.array(out_data, dtype='uint8')

cv2.imshow('out', out_data[0, :, :])
cv2.waitKey(0)

# 将模型导出为 ONNX 格式
is_dynamic_axes = False
if is_dynamic_axes:
    input_name = 'input'
    output_name = 'output'
    torch.onnx.export(modified_model,
                      img_data,
                      r"params/net_model_modify.onnx",
                      opset_version=11,
                      input_names=[input_name], 
                      output_names=[output_name], 
                      dynamic_axes={
                          input_name: {0: 'batch_size', 2: 'in_width', 3: 'int_height'},
                          output_name: {0: 'batch_size', 2: 'out_width', 3: 'out_height'}},
                      verbose=True)
else:
    input_name = 'input'
    output_name = 'output'
    torch.onnx.export(modified_model,
                      img_data,
                      r"params/net_model_modify.onnx",
                      opset_version=11,
                      input_names=[input_name], 
                      output_names=[output_name],  
                      verbose=True)

原有模型和修改后的模型onnx计算图如下:
在这里插入图片描述


网站公告

今日签到

点亮在社区的每一天
去签到