source_model 路径下存在以下几个 checkpoint：
model_checkpoint_path: "model.ckpt-457157707"
all_model_checkpoint_paths: "model.ckpt-456023526" ,all_model_checkpoint_paths: "model.ckpt-456332667" ,all_model_checkpoint_paths: "model.ckpt-456332668",all_model_checkpoint_paths: "model.ckpt-456832684" ,all_model_checkpoint_paths: "model.ckpt-457157707"
现在将这些 ckpt 的参数进行平均，合并成一个 model.ckpt-457157708。
import os

import numpy as np
import tensorflow as tf
# Checkpoints to average, listed oldest first; the last entry is the newest.
ckpt_files = [
    "model.ckpt-456023526",
    "model.ckpt-456332667",
    "model.ckpt-456332668",
    "model.ckpt-456832684",
    "model.ckpt-457157707",
]
ckpt_files = [os.path.join("source_model", ckpt_file) for ckpt_file in ckpt_files]

# NOTE(review): tf.Session / tf.train.NewCheckpointReader / tf.get_variable are
# TensorFlow 1.x APIs; under TF 2.x they are presumably reachable only through
# tf.compat.v1 -- confirm against the installed version.


def average_tensors(values):
    """Return the element-wise mean of same-shaped tensors, keeping their dtype.

    Args:
        values: Non-empty sequence of NumPy arrays/scalars holding the same
            variable read from different checkpoints, oldest first.

    Returns:
        A NumPy array. For floating-point (or complex) inputs it is the mean
        over ``values`` cast back to the input dtype. For any other dtype
        (integer step counters, booleans, ...) averaging would silently turn
        the variable into float64, so the value from the last (newest)
        checkpoint is returned unchanged instead.

    Raises:
        ValueError: If ``values`` is empty.
    """
    if len(values) == 0:
        raise ValueError("no tensors to average")
    latest = np.asarray(values[-1])
    if not np.issubdtype(latest.dtype, np.inexact):
        return latest
    # Cast back explicitly so the saved variable keeps the original dtype.
    return np.asarray(np.mean(values, axis=0), dtype=latest.dtype)


def load_checkpoint_tensors(ckpt_paths):
    """Read every variable from every checkpoint.

    Args:
        ckpt_paths: Iterable of checkpoint path prefixes.

    Returns:
        Dict mapping each variable name to the list of its values, one entry
        per checkpoint that contains the variable, in ``ckpt_paths`` order.
    """
    tensors_by_name = {}
    for ckpt_path in ckpt_paths:
        reader = tf.train.NewCheckpointReader(ckpt_path)
        for name in reader.get_variable_to_shape_map():
            tensors_by_name.setdefault(name, []).append(reader.get_tensor(name))
    return tensors_by_name


def save_checkpoint(variables, output_prefix):
    """Create one TF variable per entry of ``variables`` and save a checkpoint.

    Args:
        variables: Dict mapping variable name to its NumPy value.
        output_prefix: Path prefix of the checkpoint to write.
    """
    with tf.Session() as sess:
        for name, value in variables.items():
            tf.get_variable(name, initializer=value)
        # One initializer run for all variables instead of one run per variable.
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.save(sess, output_prefix)


def main():
    """Average the parameters of ``ckpt_files`` into a single new checkpoint."""
    tensors_by_name = load_checkpoint_tensors(ckpt_files)
    average_vars = {
        name: average_tensors(values) for name, values in tensors_by_name.items()
    }
    save_checkpoint(average_vars, "source_model/model.ckpt-457157708")


if __name__ == "__main__":
    main()