Averaging the parameters of multiple checkpoints

Published: 2023-09-14

The source_model directory contains the following checkpoints, as listed in its checkpoint state file:
model_checkpoint_path: "model.ckpt-457157707"
all_model_checkpoint_paths: "model.ckpt-456023526"
all_model_checkpoint_paths: "model.ckpt-456332667"
all_model_checkpoint_paths: "model.ckpt-456332668"
all_model_checkpoint_paths: "model.ckpt-456832684"
all_model_checkpoint_paths: "model.ckpt-457157707"
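The listing above is the text content of TensorFlow's checkpoint state file, so a script can read the checkpoint list from it instead of hard-coding the paths. Within TensorFlow, tf.train.get_checkpoint_state("source_model") returns these same fields (model_checkpoint_path and all_model_checkpoint_paths). As a dependency-free sketch, the hypothetical helper parse_checkpoint_state below extracts them with a regular expression, assuming the simple key: "value" layout shown above:

```python
import re


def parse_checkpoint_state(text):
    """Return (model_checkpoint_path, all_model_checkpoint_paths) from the text
    of a checkpoint state file. Hypothetical helper: assumes quoted key: "value"
    pairs and ignores unquoted fields such as timestamps."""
    latest, all_paths = None, []
    for key, value in re.findall(r'(\w+):\s*"([^"]*)"', text):
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            all_paths.append(value)
    return latest, all_paths


state = (
    'model_checkpoint_path: "model.ckpt-457157707"\n'
    'all_model_checkpoint_paths: "model.ckpt-456023526"\n'
    'all_model_checkpoint_paths: "model.ckpt-457157707"\n'
)
latest, paths = parse_checkpoint_state(state)
print(latest)  # model.ckpt-457157707
print(paths)   # ['model.ckpt-456023526', 'model.ckpt-457157707']
```

For real scripts the TensorFlow call is the safer choice, since it reads the file through the proper protobuf text parser.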

The goal is to average the parameters of these checkpoints and merge them into a single new checkpoint, model.ckpt-457157708.
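Written out, for a variable v stored in each of the N checkpoints (N = 5 here), with θ_v^(i) its value in the i-th checkpoint, the merged checkpoint holds the equal-weight, element-wise mean, which is what np.mean(values, axis=0) computes over the stacked values:

```latex
\bar{\theta}_v = \frac{1}{N} \sum_{i=1}^{N} \theta_v^{(i)}, \qquad N = 5
```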

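import os

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph API; tf.compat.v1 exists in late TF 1.x and in TF 2.x

tf.disable_v2_behavior()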

# All checkpoints to average (checkpoint prefixes, without the .index / .data-* suffixes)
ckpt_files = ["model.ckpt-456023526", "model.ckpt-456332667", "model.ckpt-456332668", "model.ckpt-456832684", "model.ckpt-457157707"]
ckpt_files = [os.path.join("source_model", ckpt_file) for ckpt_file in ckpt_files]

# Collect each variable's values from every checkpoint: {name: [array, ...]}
all_model_vars = {}

for ckpt_file in ckpt_files:
    reader = tf.train.NewCheckpointReader(ckpt_file)
    model_vars = reader.get_variable_to_shape_map()
    for var in model_vars:
        if var not in all_model_vars:
            all_model_vars[var] = []
        all_model_vars[var].append(reader.get_tensor(var))

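# Average each variable element-wise across the checkpoints.
# Floating-point variables keep their original dtype (np.mean of integers returns float64);
# non-float variables such as the int64 global_step cannot be meaningfully averaged,
# so their value is taken from the last (newest) checkpoint instead.
average_vars = {}
for var, values in all_model_vars.items():
    if np.issubdtype(values[0].dtype, np.floating):
        average_vars[var] = np.mean(values, axis=0).astype(values[0].dtype)
    else:
        average_vars[var] = values[-1]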

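# Build a fresh graph with one variable per averaged value and save it as a new checkpoint.
# Values are fed through placeholders instead of being baked into the graph as constants,
# which keeps the GraphDef under its 2 GB limit for large models.
with tf.Graph().as_default():
    assign_ops, feed_dict = [], {}
    for var_name, var_value in average_vars.items():
        dtype = tf.as_dtype(var_value.dtype)
        var = tf.get_variable(var_name, shape=var_value.shape, dtype=dtype)
        placeholder = tf.placeholder(dtype, shape=var_value.shape)
        assign_ops.append(tf.assign(var, placeholder))
        feed_dict[placeholder] = var_value

    # saver.save also rewrites source_model/checkpoint so that it points at the new checkpoint
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(assign_ops, feed_dict=feed_dict)
        saver.save(sess, "source_model/model.ckpt-457157708")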
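The averaging rule itself does not depend on TensorFlow, so it can be checked in isolation. Below is a minimal NumPy sketch, assuming each checkpoint has already been loaded into a dict mapping variable names to arrays (for example with reader.get_tensor as in the script above) and that every checkpoint contains the same variables with the same shapes. The helper name average_state_dicts and the sample variable names are hypothetical:

```python
import numpy as np


def average_state_dicts(state_dicts):
    """Average a list of {variable_name: array} dicts element-wise.

    Floating-point variables are averaged and keep their original dtype;
    other variables (e.g. an integer global_step) are copied from the last dict.
    """
    averaged = {}
    for name in state_dicts[0]:
        values = [np.asarray(sd[name]) for sd in state_dicts]
        if np.issubdtype(values[0].dtype, np.floating):
            averaged[name] = np.mean(values, axis=0).astype(values[0].dtype)
        else:
            averaged[name] = values[-1]
    return averaged


# Two tiny hypothetical "checkpoints"
ckpts = [
    {"dense/kernel": np.array([1.0, 2.0], dtype=np.float32), "global_step": np.int64(100)},
    {"dense/kernel": np.array([3.0, 4.0], dtype=np.float32), "global_step": np.int64(200)},
]
avg = average_state_dicts(ckpts)
print(avg["dense/kernel"])  # [2. 3.]
print(avg["global_step"])   # 200
```

Copying non-float variables from the newest checkpoint is a design choice rather than a rule; it keeps the merged checkpoint restorable into the original graph, whose global_step is an integer variable.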