最快的解决方案:
tensor2tensor
有一个模块utils
带脚本avg_checkpoints.py https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py将平均权重保存在新的检查点中。假设您有一个要计算平均值的检查点列表。您有 2 个使用选项:
-
从命令行
TRAIN_DIR=path_to_your_model_folder
FNC_PATH=path_to_tensor2tensor+'/utils/avg.checkpoints.py'
CKPTS=model.ckpt-10000,model.ckpt-20000,model.ckpt-100000
python3 $FNC_PATH --prefix=$TRAIN_DIR --checkpoints=$CKPTS \
--output_path="${TRAIN_DIR}averaged.ckpt"
-
从您自己的代码(使用os.system
):
import os
os.system(
"python3 "+FNC_DIR+" --prefix="+TRAIN_DIR+" --checkpoints="+CKPTS+
" --output_path="+TRAIN_DIR+"averaged.ckpt"
)
作为指定检查点列表并使用--checkpoints
参数,你可以使用--num_checkpoints=10
计算最后 10 个检查点的平均值。
如果您不想依赖tensor2tensor
:
这是一个不依赖的代码片段tensor2tensor
,但仍然可以平均检查点数量可变(与特德的回答相反)。认为steps
是应该合并的检查点列表(例如[10000, 20000, 30000, 40000]
).
Then:
# Restore all sessions and save the weight matrices
values = []
for step in steps:
tf.reset_default_graph()
path = model_path+'/model.ckpt-'+str(step)
with tf.Session() as sess:
saver = tf.train.import_meta_graph(path+'.meta')
saver.restore(sess, path)
values.append(sess.run(tf.all_variables()))
# Average weights
variables = tf.all_variables()
all_assign = []
for ind, var in enumerate(variables):
weights = np.concatenate(
[np.expand_dims(w[ind],axis=0) for w in values],
axis=0
)
all_assign.append(tf.assign(var, np.mean(weights, axis=0))
然后您可以按照您的喜好继续操作,例如保存平均检查点:
# Now save the new values into a separate checkpoint
with tf.Session() as sess_test:
sess_test.run(all_assign)
saver = tf.train.Saver()
saver.save(sess_test, model_path+'/average_'+str(num_checkpoints))