模型加载pytorch版本不匹配的解决思路

发布于:2024-09-17 ⋅ 阅读:(365) ⋅ 点赞:(0)

模型部署总是会遇到pytorch版本推理与训练不匹配的问题,一般报错:

AttributeError: Can't get attribute '_rebuild_parameter_v2' on <module 'torch._utils' from '/usr/local/python3.9.0/lib/python3.9/site-packages/torch/_utils.py'>

提示pytorch 中_utils.py没有实现这个方法’_rebuild_parameter_v2’ ,这就表明新的pytorch增加了一些方法,而旧的pytorch没有实现。为了避免环境升级各种乱七八糟的事情,那我们就手动实现它,在仅实现这个function的情况下实现pytorch的伪升级。

(1)首先看pytorch的github

发现’_rebuild_parameter_v2’这个方法是新添加的
在这里插入图片描述
源码路径实现在这:
在这里插入图片描述

(2)那我们就把这段代码抄过来

攒成下面一块:

import torch._utils
try:
    torch._utils._rebuild_parameter_v2
except AttributeError:
    def _set_obj_state(obj, state):
        if isinstance(state, tuple):
            if not len(state) == 2:
                raise RuntimeError(f"Invalid serialized state: {state}")
            dict_state = state[0]
            slots_state = state[1]
        else:
            dict_state = state
            slots_state = None

        for k, v in dict_state.items():
            setattr(obj, k, v)

        if slots_state:
            for k, v in slots_state.items():
                setattr(obj, k, v)
        return obj
    def _rebuild_parameter_v2(data, requires_grad, backward_hooks, state):
        param = torch.nn.Parameter(data, requires_grad)
        param._backward_hooks = backward_hooks
        param = _set_obj_state(param, state)
        return param
    torch._utils._rebuild_parameter_v2 = _rebuild_parameter_v2

这样,每次加载模型的时候,把上面那段代码拷贝上去,就能执行,不会报错了
在这里插入图片描述


网站公告

今日签到

点亮在社区的每一天
去签到