爪哇importGraphDef()
函数仅导入计算图(由tf.train.write_graph
在你的Python代码中),它没有加载经过训练的变量的值(存储在检查点中),这就是为什么你会收到一个错误,抱怨未初始化的变量。
The TensorFlow SavedModel 格式 https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/saved_model另一方面包括有关模型的所有信息(图形、检查点状态、其他元数据)并在您想要使用的 Java 中使用SavedModelBundle.load https://www.tensorflow.org/versions/r1.1/api_docs/java/reference/org/tensorflow/SavedModelBundle创建使用经过训练的变量值初始化的会话。
要从 Python 导出这种格式的模型,您可能需要查看相关问题将重新训练的 inception SavedModel 部署到 google cloud ml 引擎 https://stackoverflow.com/questions/43001719/deploy-retrained-inception-savedmodel-to-google-cloud-ml-engine/43002175
对于您的情况,这应该类似于 Python 中的以下内容:
def save_model(session, input_tensor, output_tensor):
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
)
b = saved_model_builder.SavedModelBuilder('/tmp/model')
b.add_meta_graph_and_variables(session,
[tf.saved_model.tag_constants.SERVING],
signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
b.save()
并通过调用它save_model(session, x, yhat)
然后在 Java 中使用以下命令加载模型:
try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) {
// b.session().run(...)
}
希望有帮助。