python序列化、反序列化函数的参数,用于问题复现

发布于:2024-05-02 ⋅ 阅读:(25) ⋅ 点赞:(0)

python序列化、反序列化函数的参数,用于问题复现


一.背景
1.想dump出pytorch模型所调用基础算子的参数
2.对于Tensor、ndarray,只保存type和shape,不保存值
3.之后可通过以上保存的信息,生成算子的参数,运行单算子
二.目前支持以下类型及嵌套
Tensor,ndarray,int,float,list,tuple

一.代码

# -*- coding: utf-8 -*-
'''
一.背景
1.想dump出pytorch模型所调用基础算子的参数
2.对于Tensor、ndarray,只保存type和shape,不保存值
3.之后可通过以上保存的信息,生成算子的参数,运行单算子
二.目前支持以下类型及嵌套:
Tensor,ndarray,int,float,list,tuple
'''

import torch
import numpy as np
import pickle
from dataclasses import dataclass
from typing import Any

# Pickle file that do_something() writes the captured argument descriptors to
# and that main() reads back when reproducing the call.
var_save_path="vars.pkl"

@dataclass
class DataDescriptor:
    """Value-free description of a single function argument.

    Tensors and ndarrays are described by shape and dtype only (their
    contents are not stored); ints and floats keep their value; lists and
    tuples hold a list of nested ``DataDescriptor`` objects in ``value``.
    """

    # Name of the described object's class: "Tensor", "ndarray", "int",
    # "float", "list" or "tuple".
    class_name: Any
    # Shape as a list of ints for Tensor/ndarray, otherwise None.
    shape: Any
    # Scalar value for int/float, list of DataDescriptor for list/tuple,
    # otherwise None.
    value: Any
    # dtype for Tensor/ndarray, the Python type for int/float, None for
    # list/tuple.
    dtype: Any

    def data(self):
        """Build a stand-in object matching this description.

        Returns:
            A zero-filled tensor/array of the recorded shape and dtype, the
            recorded scalar, or a list/tuple of recursively rebuilt elements.

        Raises:
            TypeError: if ``class_name`` is not one of the supported names.
        """
        if self.class_name == "Tensor":
            return torch.zeros(self.shape, dtype=self.dtype)
        if self.class_name in ("int", "float"):
            return self.value
        if self.class_name == "ndarray":
            return np.zeros(self.shape, dtype=self.dtype)
        if self.class_name in ("list", "tuple"):
            items = [t.data() for t in self.value]
            # Rebuild the recorded container type so a tuple argument is
            # reproduced as a tuple rather than a list.
            return tuple(items) if self.class_name == "tuple" else items
        # Raising a plain string is itself a TypeError in Python 3 and hides
        # the message; raise a real exception of the same type instead.
        raise TypeError(f"Unknown:{self.class_name}")

    def __repr__(self) -> str:
        """Return ``class_name(part-part-...)`` with shape, value(s) and dtype."""
        parts = []
        # Compare against None explicitly: an empty shape (0-d tensor) and
        # falsy scalars such as 0 or 0.0 are still meaningful and must print.
        if self.shape is not None:
            parts.append("shape:({})".format(",".join(str(x) for x in self.shape)))
        if self.value is not None:
            if self.class_name in ("list", "tuple"):
                parts.extend(str(t) for t in self.value)
            else:
                parts.append(str(self.value))
        if self.dtype is not None and self.class_name in ("Tensor", "ndarray"):
            parts.append(str(self.dtype))
        return "{}({})".format(self.class_name, "-".join(parts))

class InputDescriptor:
    """Records the positional and keyword arguments of one call as descriptors.

    ``input_vars`` holds one descriptor per positional argument and
    ``input_kwargs`` maps each keyword name to its descriptor. The whole
    object can be pickled to disk and later turned back into stand-in
    arguments with ``data()``.
    """

    def __init__(self) -> None:
        self.input_vars = []
        self.input_kwargs = {}

    def serialize(self, path):
        """Pickle this object to the file at ``path`` (overwriting it)."""
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def deserialize(cls, path):
        """Load and return the object pickled in the file at ``path``.

        NOTE(security): ``pickle.load`` can execute arbitrary code while
        loading; only use this on files produced by ``serialize`` from a
        trusted source.
        """
        with open(path, "rb") as f:
            return pickle.load(f)

    def data(self):
        """Rebuild stand-in arguments from the stored descriptors.

        Returns:
            ``(input_vars, input_kwargs)``: a list of positional values and a
            dict of keyword values, each produced by the descriptor's
            ``data()`` method.
        """
        input_vars = [var.data() for var in self.input_vars]
        input_kwargs = {k: v.data() for k, v in self.input_kwargs.items()}
        return input_vars, input_kwargs

    def _save_var(self, v):
        """Return a ``DataDescriptor`` for ``v``, recursing into lists/tuples.

        Dispatch is on the exact class name of ``v``, so only Tensor,
        ndarray, int, float, list and tuple are accepted.

        Raises:
            TypeError: if ``v`` is of any other class.
        """
        class_name = v.__class__.__name__
        if class_name in ("Tensor", "ndarray"):
            # Only shape and dtype are kept; the contents are not stored.
            return DataDescriptor(class_name, list(v.shape), None, v.dtype)
        if class_name in ("int", "float"):
            return DataDescriptor(class_name, None, v, type(v))
        if class_name in ("list", "tuple"):
            nested = [self._save_var(t) for t in v]
            return DataDescriptor(class_name, None, nested, None)
        # Raising a plain string is itself a TypeError in Python 3 and hides
        # the message; raise a real exception of the same type instead.
        raise TypeError(f"Unknown:{class_name}")

    def save_vars(self, *args, **kwargs):
        """Append descriptors for ``args`` and store descriptors for ``kwargs``."""
        self.input_vars.extend(self._save_var(arg) for arg in args)
        for k, v in kwargs.items():
            self.input_kwargs[k] = self._save_var(v)

    def __repr__(self) -> str:
        """Return the positional and keyword descriptors joined by ``#``."""
        return str(self.input_vars) + "#" + str(self.input_kwargs)
def do_something(*args,**kwargs):
    """Example user function: print its arguments and dump their descriptors.

    The arguments are recorded in an InputDescriptor, which is then pickled
    to ``var_save_path`` so the call can be reproduced later.

    Returns:
        Always True.
    """
    print(args, kwargs)

    # Capture the inputs and write the serialized descriptors to disk.
    recorder = InputDescriptor()
    recorder.save_vars(*args, **kwargs)
    recorder.serialize(var_save_path)
    return True

def main():
    """Demo: call do_something, then replay the call from the pickled descriptors."""

    def make_tensor():
        return torch.zeros((1, 23, 128), dtype=torch.float32)

    def make_array():
        return np.zeros((2, 3, 4), dtype=np.float32)

    # 1. Prepare the input arguments.
    input_vars = [
        make_tensor(),
        1,
        2.0,
        (1, 2, 3),
        make_array(),
        [make_tensor(), make_tensor()],
        [make_array()],
    ]
    input_kwargs = {
        "a": 1,
        "b": 4,
        "c": input_vars,
        "d": make_tensor(),
    }

    # 2. Call the user function; it serializes its arguments internally.
    do_something(*input_vars, **input_kwargs)

    rule = "#" * 32
    print(rule + " reproduce " + rule)

    # 3. Load the serialized descriptors.
    desc = InputDescriptor.deserialize(var_save_path)

    # 4. Generate stand-in arguments from them.
    _input_vars, _input_kwargs = desc.data()

    # 5. Call the user function again with the generated arguments.
    do_something(*_input_vars, **_input_kwargs)

    # 6. Print the recorded descriptors.
    print(desc)

# Run the demo only when this file is executed as a script.
if __name__ == "__main__":  
    main()

二.输出

[Tensor(shape:(1,23,128)-torch.float32), int(1), float(2.0), tuple(int(1)-int(2)-int(3)), 
ndarray(shape:(2,3,4)-float32), list(
Tensor(shape:(1,23,128)-torch.float32)-Tensor(shape:(1,23,128)-torch.float32)), 
list(ndarray(shape:(2,3,4)-float32))]
#{'a': int(1), 'b': int(4),
'c': list(Tensor(shape:(1,23,128)-torch.float32)
-int(1)-float(2.0)-tuple(int(1)-int(2)-int(3))
-ndarray(shape:(2,3,4)-float32)-list(Tensor(shape:(1,23,128)
-torch.float32)-Tensor(shape:(1,23,128)-torch.float32))
-list(ndarray(shape:(2,3,4)-float32))), 'd': Tensor(shape:(1,23,128)-torch.float32)}

网站公告

今日签到

点亮在社区的每一天
去签到