There is a way to convert constants back into trainable variables in TensorFlow, via the Graph Editor. However, you need to specify the nodes you want to convert, since I'm not sure there is a way to detect this automatically in a reliable fashion.

The steps are as follows:

Step 1: Load the frozen graph

We load our .pb file into a graph object.
    import tensorflow as tf

    # Load protobuf as graph, given filepath
    def load_pb(path_to_pb):
        with tf.gfile.GFile(path_to_pb, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name='')
            return graph

    tf_graph = load_pb('frozen_graph.pb')
Step 2: Find the constants that need to be converted

Here are two ways to list the names of the nodes in the graph:

- Use this script to print them: https://gist.github.com/sunsided/88d24bf44068fe0fe5b88f09a1bee92a
- Print them directly:

      print([n.name for n in tf_graph.as_graph_def().node])
The nodes you want to convert are likely named something like "Const". To be sure, load the graph in Netron (https://github.com/lutzroeder/netron) to see which tensors are storing the trainable weights. Often, it is safe to assume that all const nodes were once variables.

Once you have identified these nodes, store their names in a list:
to_convert = [...] # names of tensors to convert
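If you go with the assumption that every Const node was once a variable, you can build `to_convert` programmatically by filtering the node list on its op type. A minimal sketch of that filter, using a hypothetical list of `(name, op)` pairs in place of the real `tf_graph.as_graph_def().node` (whose entries expose `.name` and `.op` attributes):

```python
# Hypothetical stand-in for tf_graph.as_graph_def().node; in a real
# graph, each node has .name and .op attributes instead of a tuple.
nodes = [
    ('conv1/weights', 'Const'),
    ('conv1/Conv2D', 'Conv2D'),
    ('conv1/biases', 'Const'),
    ('conv1/BiasAdd', 'BiasAdd'),
]

# Keep only the Const nodes, assuming each constant was once a variable.
to_convert = [name for name, op in nodes if op == 'Const']
print(to_convert)  # ['conv1/weights', 'conv1/biases']
```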
Step 3: Convert the constants to variables

Run this code to convert your specified constants. It essentially creates a corresponding variable for each constant, then uses the Graph Editor to detach each constant from the graph and hook the variable up in its place.
    import numpy as np
    import tensorflow as tf
    import tensorflow.contrib.graph_editor as ge

    const_var_name_pairs = []
    with tf_graph.as_default() as g:
        for name in to_convert:
            tensor = g.get_tensor_by_name('{}:0'.format(name))
            with tf.Session() as sess:
                tensor_as_numpy_array = sess.run(tensor)
            var_shape = tensor.get_shape()
            # Give each variable a name that doesn't already exist in the graph
            var_name = '{}_turned_var'.format(name)
            # Create a TensorFlow variable initialized by the values of the original const.
            var = tf.get_variable(name=var_name, dtype='float32', shape=var_shape,
                                  initializer=tf.constant_initializer(tensor_as_numpy_array))
            # We want to keep track of our variables' names for later.
            const_var_name_pairs.append((name, var_name))

        # At this point, we added a bunch of tf.Variables to the graph, but they're
        # not connected to anything.
        # The magic: we use the TF Graph Editor to swap the Constant nodes' outputs with
        # the outputs of our newly created Variables.
        for const_name, var_name in const_var_name_pairs:
            const_op = g.get_operation_by_name(const_name)
            var_reader_op = g.get_operation_by_name(var_name + '/read')
            ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
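For intuition, the swap reroutes every consumer that was reading the constant's output tensor so that it reads the new variable's `/read` tensor instead. Here is a toy, non-TensorFlow sketch of that rewiring, with the graph modeled as a hypothetical dict mapping each op name to its list of input tensor names (the real Graph Editor operates on actual graph ops, not this structure):

```python
# Toy graph: op name -> input tensor names. 'MatMul' currently consumes
# the frozen constant's output tensor 'weights:0'.
graph_inputs = {
    'MatMul': ['input:0', 'weights:0'],
    'weights_turned_var/read': ['weights_turned_var:0'],
}

def redirect_consumers(graph, old_op, new_op):
    # Point every input that referenced old_op's output at new_op's output.
    for op, inputs in graph.items():
        graph[op] = [new_op + ':0' if t == old_op + ':0' else t
                     for t in inputs]

redirect_consumers(graph_inputs, 'weights', 'weights_turned_var/read')
print(graph_inputs['MatMul'])  # ['input:0', 'weights_turned_var/read:0']
```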
Step 4: Save the result as a .ckpt

    # Build the Saver inside tf_graph's context so it sees the new variables.
    with tf_graph.as_default():
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            save_path = tf.train.Saver().save(sess, 'model.ckpt')
            print("Model saved in path: %s" % save_path)
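To double-check that the conversion preserved the weights, you can compare the numpy arrays you pulled out of the constants in step 3 against the values restored from the checkpoint. A minimal sketch of that comparison, assuming you have collected both sides into hypothetical dicts keyed by the original const names:

```python
import numpy as np

# Hypothetical dicts: `original` holds the arrays read from the constants,
# `restored` holds the values read back from the checkpoint's variables.
original = {'weights': np.array([[1.0, 2.0], [3.0, 4.0]])}
restored = {'weights': np.array([[1.0, 2.0], [3.0, 4.0]])}

mismatches = [k for k in original if not np.allclose(original[k], restored[k])]
print('mismatched tensors: %s' % mismatches)  # mismatched tensors: []
```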
And voilà! You should be done at this point :) I was able to get this working myself, and verified that the model weights are preserved; the only difference is that the graph is now trainable. Let me know if you run into any issues.