(注意,我已经解决了我的问题并将代码发布在底部)
我正在使用 TensorFlow,后端处理必须在 Java 中进行。我从 中取出了其中一个模型https://developers.google.com/machine-learning/crash-course https://developers.google.com/machine-learning/crash-course并使用 tf.saved_model.save(my_model,"house_price_median_venue") 保存它(使用 docker 容器)。我复制了模型并将其加载到 Java 中(使用从源代码构建的 2.0 内容,因为我使用的是 Windows)。
我可以加载模型并运行它:
try (SavedModelBundle model = SavedModelBundle.load("./house_price_median_income", "serve")) {
try (Session session = model.session()) {
Session.Runner runner = session.runner();
float[][] in = new float[][]{ {2.1518f} } ;
Tensor<?> jack = Tensor.create(in);
runner.feed("serving_default_layer1_input", jack);
float[][] probabilities = runner.fetch("StatefulPartitionedCall").run().get(0).copyTo(new float[1][1]);
for (int i = 0; i < probabilities.length; ++i) {
System.out.println(String.format("-- Input #%d", i));
for (int j = 0; j < probabilities[i].length; ++j) {
System.out.println(String.format("Class %d - %f", i, probabilities[i][j]));
}
}
}
}
上面的内容被硬编码为输入和输出,但我希望能够读取模型并提供一些信息,以便最终用户可以选择输入和输出等。
我可以使用 python 命令获取输入和输出:saved_model_cli show --dir ./house_price_median_venue --all
我想要做的是通过 Java 获取输入和输出,这样我的代码就不需要执行 python 脚本来获取它们。我可以通过以下方式进行操作:
Graph graph = model.graph();
Iterator<Operation> itr = graph.operations();
while (itr.hasNext()) {
GraphOperation e = (GraphOperation)itr.next();
System.out.println(e);
这将输入和输出都输出为“操作”,但我如何知道它是输入和/或输出? python 工具使用 SignatureDef 但它似乎根本没有出现在 TensorFlow 2.0 java 中。我是否遗漏了一些明显的东西,或者只是 TensforFlow 2.0 Java 库中遗漏了它?
注意,我已使用下面的答案帮助解决了我的问题。这是我的完整代码,以防将来有人喜欢。请注意,这是 TF 2.0,并使用下面提到的 SNAPSHOT。我做了一些假设,但它展示了如何提取输入和输出,然后使用它们来运行模型
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.Session.Run;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.GraphOperation;
import org.tensorflow.proto.framework.SignatureDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.proto.framework.MetaGraphDef;
import java.util.Map;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.tools.Shape;
import java.nio.FloatBuffer;
import org.tensorflow.tools.buffer.DataBuffers;
import org.tensorflow.tools.ndarray.FloatNdArray;
import org.tensorflow.tools.ndarray.StdArrays;
import org.tensorflow.proto.framework.TensorInfo;
public class v2tensor {
public static void main(String[] args) {
try (SavedModelBundle savedModel = SavedModelBundle.load("./house_price_median_income", "serve")) {
SignatureDef modelInfo = savedModel.metaGraphDef().getSignatureDefMap().get("serving_default");
TensorInfo input1 = null;
TensorInfo output1 = null;
Map<String, TensorInfo> inputs = modelInfo.getInputsMap();
for(Map.Entry<String, TensorInfo> input : inputs.entrySet()) {
if (input1 == null) {
input1 = input.getValue();
System.out.println(input1.getName());
}
System.out.println(input);
}
Map<String, TensorInfo> outputs = modelInfo.getOutputsMap();
for(Map.Entry<String, TensorInfo> output : outputs.entrySet()) {
if (output1 == null) {
output1=output.getValue();
}
System.out.println(output);
}
try (Session session = savedModel.session()) {
Session.Runner runner = session.runner();
FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{ { 2.1518f } } );
try (Tensor<TFloat32> jack = TFloat32.tensorOf(matrix) ) {
runner.feed(input1.getName(), jack);
try ( Tensor<TFloat32> rezz = runner.fetch(output1.getName()).run().get(0).expect(TFloat32.DTYPE) ) {
TFloat32 data = rezz.data();
data.scalars().forEachIndexed((i, s) -> {
System.out.println(s.getFloat());
} );
}
}
}
} catch (TensorFlowException ex) {
ex.printStackTrace();
}
}
}