我在使用 java tensorflow API 时遇到问题。我使用 python tensorflow API 运行训练,生成文件 output_graph.pb 和 output_labels.txt。现在,出于某种原因,我想使用这些文件作为 java tensorflow API 中 LabelImage 模块的输入。我认为一切都会很好,因为该模块只需要一个 .pb 和一个 .txt。然而,当我运行该模块时,我收到此错误:
2017-04-26 10:12:56.711402: W tensorflow/core/framework/op_def_util.cc:332] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
Exception in thread "main" java.lang.IllegalArgumentException: No Operation named [input] in the Graph
at org.tensorflow.Session$Runner.operationByName(Session.java:343)
at org.tensorflow.Session$Runner.feed(Session.java:137)
at org.tensorflow.Session$Runner.feed(Session.java:126)
at it.zero11.LabelImage.executeInceptionGraph(LabelImage.java:115)
at it.zero11.LabelImage.main(LabelImage.java:68)
如果您能帮我找出问题所在,我将不胜感激。此外,我想问你是否有一种方法可以从 java tensorflow API 运行训练,因为这会让事情变得更容易。
更准确地说:
事实上,我并没有使用自己编写的代码,至少在相关步骤中是这样。我所做的就是用这个模块进行训练,https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py,向其提供包含图像的目录,这些图像根据其描述划分在子目录中。特别是,我认为这些是生成输出的行:
output_graph_def = graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
f.write(output_graph_def.SerializeToString())
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n')
然后,我将输出(在同一个 graph.pb 和一些 labels.txt 上)作为该 java 模块的输入:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java,替换默认输入。我得到的错误是上面报告的错误。