背景:
1.希望dump出PyTorch模型所调用的基础算子的参数
2.对于Tensor和ndarray,只保存类型(type)和形状(shape),不保存具体数值
3.之后可以根据保存的信息重新生成算子的参数,单独运行该算子
目前支持以下类型及其嵌套:
Tensor、ndarray、int、float、list、tuple
一.代码
# -*- coding: utf-8 -*-
"""Record the arguments of a call so that the call can be replayed later.

1. Goal: dump the arguments passed to the basic operators invoked by a
   PyTorch model.
2. For Tensor and ndarray arguments only the type, shape and dtype are
   saved, not the values.
3. The saved information can later be used to regenerate the operator's
   arguments and run that single operator in isolation.

Supported types (and nestings of them): Tensor, ndarray, int, float,
list, tuple.
"""
import torch
import numpy as np
import pickle
from dataclasses import dataclass
from typing import Any
# Default file that recorded call arguments are pickled to and loaded from.
var_save_path = "vars.pkl"
@dataclass
class DataDescriptor:
    """Description of one recorded argument, sufficient to rebuild a stand-in.

    Tensors and ndarrays are described by shape and dtype only (their values
    are not stored); ints and floats keep their value; lists and tuples keep
    a list of nested ``DataDescriptor`` items in ``value``.
    """

    class_name: Any  # class name of the recorded value, e.g. "Tensor", "int", "tuple"
    shape: Any  # list of dimension sizes for Tensor/ndarray, otherwise None
    value: Any  # scalar for int/float, list of child descriptors for list/tuple, else None
    dtype: Any  # torch/numpy dtype for Tensor/ndarray, Python type for int/float, else None

    def data(self):
        """Rebuild a placeholder value matching this description.

        Tensors and ndarrays are created filled with zeros; ints and floats
        return the recorded value; lists and tuples are rebuilt recursively
        and keep their original container type.

        Raises:
            TypeError: if ``class_name`` is not a supported type.
        """
        if self.class_name == "Tensor":
            return torch.zeros(self.shape, dtype=self.dtype)
        if self.class_name in ("int", "float"):
            return self.value
        if self.class_name == "ndarray":
            return np.zeros(self.shape, dtype=self.dtype)
        if self.class_name in ("list", "tuple"):
            items = [item.data() for item in self.value]
            # Keep the container type so tuple arguments are not replayed as lists.
            return tuple(items) if self.class_name == "tuple" else items
        # Must raise an exception instance; raising a plain str is itself an error.
        raise TypeError(f"Unknown:{self.class_name}")

    def __repr__(self) -> str:
        """Compact form such as ``Tensor(shape:(1,2)-torch.float32)`` or ``int(1)``."""
        parts = []
        # Compare against None (not truthiness) so that meaningful falsy values
        # such as 0, 0.0 or the empty shape of a 0-d tensor are still shown.
        if self.shape is not None:
            parts.append("shape:({})".format(",".join(str(x) for x in self.shape)))
        if self.value is not None:
            if self.class_name in ("list", "tuple"):
                parts.extend(str(item) for item in self.value)
            else:
                parts.append(str(self.value))
        if self.dtype is not None and self.class_name in ("Tensor", "ndarray"):
            parts.append(str(self.dtype))
        return "{}({})".format(self.class_name, "-".join(parts))
class InputDescriptor:
    """Recorder for the positional and keyword arguments of a call.

    Arguments are converted to ``DataDescriptor`` objects (shape/dtype only
    for tensors and arrays). The recorder can be pickled to disk and later
    turned back into placeholder arguments with ``data()``.
    """

    def __init__(self) -> None:
        self.input_vars = []  # one DataDescriptor per positional argument, in call order
        self.input_kwargs = {}  # keyword name -> DataDescriptor

    def serialize(self, path):
        """Pickle this recorder to ``path``, overwriting any existing file."""
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def deserialize(cls, path):
        """Load a recorder previously written by ``serialize``.

        Security: ``pickle.load`` can execute arbitrary code while loading;
        only load files that come from a trusted source.
        """
        with open(path, "rb") as f:
            return pickle.load(f)

    def data(self):
        """Rebuild ``(args, kwargs)`` placeholders from the recorded descriptors.

        Returns:
            A list of positional values and a dict of keyword values.
        """
        input_vars = [var.data() for var in self.input_vars]
        input_kwargs = {k: v.data() for k, v in self.input_kwargs.items()}
        return input_vars, input_kwargs

    def _save_var(self, v):
        """Convert one argument into a ``DataDescriptor``, recursing into lists/tuples.

        Raises:
            TypeError: if ``v`` (or a nested element) has an unsupported type.
        """
        class_name = v.__class__.__name__
        # isinstance also accepts subclasses (e.g. torch.nn.Parameter); they are
        # recorded under the base name so DataDescriptor.data() can rebuild them.
        if isinstance(v, torch.Tensor):
            return DataDescriptor("Tensor", list(v.shape), None, v.dtype)
        if class_name in ("int", "float"):
            return DataDescriptor(class_name, None, v, type(v))
        if isinstance(v, np.ndarray):
            return DataDescriptor("ndarray", list(v.shape), None, v.dtype)
        if class_name in ("list", "tuple"):
            children = [self._save_var(t) for t in v]
            return DataDescriptor(class_name, None, children, None)
        # Must raise an exception instance; raising a plain str is itself an error.
        raise TypeError(f"Unknown:{class_name}")

    def save_vars(self, *args, **kwargs):
        """Record descriptors for ``args`` and ``kwargs``, appending to earlier calls."""
        for arg in args:
            self.input_vars.append(self._save_var(arg))
        for k, v in kwargs.items():
            self.input_kwargs[k] = self._save_var(v)

    def __repr__(self) -> str:
        """Positional descriptors and keyword descriptors joined by ``#``."""
        return str(self.input_vars) + "#" + str(self.input_kwargs)
def do_something(*args, **kwargs):
    """Example user function that snapshots its own arguments.

    Prints the received arguments, records their descriptors with
    ``InputDescriptor`` and pickles them to ``var_save_path`` so the call can
    be replayed later. Always returns True.
    """
    print(args, kwargs)
    snapshot = InputDescriptor()
    snapshot.save_vars(*args, **kwargs)
    snapshot.serialize(var_save_path)
    return True
def main():
    """Demo: record a call's arguments, reload them, and replay the call."""

    def make_tensor():
        return torch.zeros((1, 23, 128), dtype=torch.float32)

    def make_array():
        return np.zeros((2, 3, 4), dtype=np.float32)

    # 1. Build arguments covering every supported type, including nesting.
    positional = [
        make_tensor(),
        1,
        2.0,
        (1, 2, 3),
        make_array(),
        [make_tensor(), make_tensor()],
        [make_array()],
    ]
    # "c" deliberately reuses the positional list to exercise deep nesting.
    keyword = {"a": 1, "b": 4, "c": positional, "d": make_tensor()}

    # 2. Call the user function; it pickles its own arguments.
    do_something(*positional, **keyword)
    print("#" * 32 + " reproduce " + "#" * 32)

    # 3-4. Load the pickled descriptors and regenerate placeholder arguments.
    replay_desc = InputDescriptor.deserialize(var_save_path)
    replay_args, replay_kwargs = replay_desc.data()

    # 5. Replay the call with the regenerated arguments.
    do_something(*replay_args, **replay_kwargs)

    # 6. Show what was recorded.
    print(replay_desc)


if __name__ == "__main__":
    main()
二.输出
[Tensor(shape:(1,23,128)-torch.float32), int(1), float(2.0), tuple(int(1)-int(2)-int(3)),
ndarray(shape:(2,3,4)-float32), list(
Tensor(shape:(1,23,128)-torch.float32)-Tensor(shape:(1,23,128)-torch.float32)),
list(ndarray(shape:(2,3,4)-float32))]
#{'a': int(1), 'b': int(4),
'c': list(Tensor(shape:(1,23,128)-torch.float32)
-int(1)-float(2.0)-tuple(int(1)-int(2)-int(3))
-ndarray(shape:(2,3,4)-float32)-list(Tensor(shape:(1,23,128)
-torch.float32)-Tensor(shape:(1,23,128)-torch.float32))
-list(ndarray(shape:(2,3,4)-float32))), 'd': Tensor(shape:(1,23,128)-torch.float32)}