pytorch 的pth格式模型转onnx格式模型 - python 实现

发布于:2024-11-29 ⋅ 阅读:(26) ⋅ 点赞:(0)

pytorch 的pth格式模型转onnx格式模型 - python 实现

#-*-coding:utf-8-*-
# date:2021-10-5
# Author: DataBall - XIAN
# function: pytorch model 2 onnx

import os
import argparse
import torch
import torch.nn as nn
import numpy as np

from network.resnet import resnet18,resnet50

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description=' Project handpose x')
    parser.add_argument('--model_path', type=str, default = r'ckpt\resnet_18_epoch-275-x96.pth',
        help = 'model_path') # 模型路径
    parser.add_argument('--model', type=str, default = 'resnet_18',
        help = '''model : resnet_34,resnet_50,resnet_101,squeezenet1_0,squeezenet1_1,shufflenetv2,shufflenet,mobilenetv2
            shufflenet_v2_x1_5 ,shufflenet_v2_x1_0 , shufflenet_v2_x2_0''') # 模型类型
    parser.add_argument('--GPUS', type=str, default = '0',
        help = 'GPUS') # GPU选择
    parser.add_argument('--test_path', type=str, default = './image/',
        help = 'test_path') # 测试图片路径
    parser.add_argument('--img_size', type=tuple , default = (96,96),
        help = 'img_size') # 输入模型图片尺寸

    print('\n/******************* {} ******************/\n'.format(parser.description))
    #--------------------------------------------------------------------------
    ops = parser.parse_args()# 解析添加参数
    #--------------------------------------------------------------------------
    print('----------------------------------')

    unparsed = vars(ops) # parse_args()方法的返回值为namespace,用vars()内建函数化为字典
    for key in unparsed.keys():
        print('{} : {}'.format(key,unparsed[key]))

    #---------------------------------------------------------------------------
    os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS

    test_path =  ops.test_path # 测试图片文件夹路径

    #---------------------------------------------------------------- 构建模型
    print('use model : %s'%(ops.model))

    if ops.model == 'resnet_50':
        model_ = resnet50(img_size=ops.img_size[0])
    elif ops.model == 'resnet_18':
        model_ = resnet18(img_size=ops.img_size[0])
    

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model_ = model_.to(device)
    model_.eval() # 设置为前向推断模式

    # 加载测试模型
    if os.access(ops.model_path,os.F_OK):# checkpoint
        chkpt = torch.load(ops.model_path, map_location=device)
        model_.load_state_dict(chkpt)
        print('load test model : {}'.format(ops.model_path))

    input_size = ops.img_size[0]
    batch_size = 1  #批处理大小
    input_shape = (3, input_size,input_size)   #输入数据,改成自己的输入shape
    print("input_size : ",input_size)

    x = torch.randn(batch_size, *input_shape)   # 生成张量
    x = x.to(device)
    export_onnx_file = "{}_size-{}.onnx".format(ops.model,input_size)		# 目的ONNX文件名
    torch.onnx.export(model_,
                        x,
                        export_onnx_file,
                        opset_version=9,
                        do_constant_folding=True,	# 是否执行常量折叠优化
                        input_names=["input"],	# 输入名
                        output_names=["output2d"],	# 输出名
                        #dynamic_axes={"input":{0:"batch_size"},  # 批处理变量
                        #                "output":{0:"batch_size"}}
                        )

脚本对应输出结果如下:


/*******************  Project handpose x ******************/

----------------------------------
model_path : ckpt\resnet_18_epoch-275-x96.pth
model : resnet_18
GPUS : 0
test_path : ./image/
img_size : (96, 96)
use model : resnet_18
load test model : ckpt\resnet_18_epoch-275-x96.pth
input_size :  96

 ​​​

助力快速掌握数据集的信息和使用方式。

数据可以如此美好!