Tensorflow 2.0 和 Java API

2023-12-23

(注意,我已经解决了我的问题并将代码发布在底部)

我正在使用 TensorFlow,后端处理必须在 Java 中进行。我从 中取出了其中一个模型https://developers.google.com/machine-learning/crash-course https://developers.google.com/machine-learning/crash-course并使用 tf.saved_model.save(my_model,"house_price_median_venue") 保存它(使用 docker 容器)。我复制了模型并将其加载到 Java 中(使用从源代码构建的 2.0 内容,因为我使用的是 Windows)。 我可以加载模型并运行它:

   try (SavedModelBundle model = SavedModelBundle.load("./house_price_median_income", "serve")) {
    try (Session session = model.session()) {
        Session.Runner runner = session.runner();
        float[][] in = new float[][]{ {2.1518f} } ;

        Tensor<?> jack = Tensor.create(in);
        runner.feed("serving_default_layer1_input", jack);

        float[][] probabilities = runner.fetch("StatefulPartitionedCall").run().get(0).copyTo(new float[1][1]);

        for (int i = 0; i < probabilities.length; ++i) {
            System.out.println(String.format("-- Input #%d", i));
            for (int j = 0; j < probabilities[i].length; ++j) {
              System.out.println(String.format("Class %d - %f", i, probabilities[i][j]));
            }
          }
    }
 }

上面的内容被硬编码为输入和输出,但我希望能够读取模型并提供一些信息,以便最终用户可以选择输入和输出等。

我可以使用 python 命令获取输入和输出:saved_model_cli show --dir ./house_price_median_venue --all

我想要做的是通过 Java 获取输入和输出,这样我的代码就不需要执行 python 脚本来获取它们。我可以通过以下方式进行操作:

 Graph graph = model.graph();
    Iterator<Operation> itr = graph.operations();
    while (itr.hasNext()) {
        GraphOperation e = (GraphOperation)itr.next();
        System.out.println(e);

这将输入和输出都输出为“操作”,但我如何知道它是输入和/或输出? python 工具使用 SignatureDef 但它似乎根本没有出现在 TensorFlow 2.0 java 中。我是否遗漏了一些明显的东西,或者只是 TensforFlow 2.0 Java 库中遗漏了它?

注意,我已使用下面的答案帮助解决了我的问题。这是我的完整代码,以防将来有人喜欢。请注意,这是 TF 2.0,并使用下面提到的 SNAPSHOT。我做了一些假设,但它展示了如何提取输入和输出,然后使用它们来运行模型

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.Session.Run;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.GraphOperation;
import org.tensorflow.proto.framework.SignatureDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.proto.framework.MetaGraphDef;
import java.util.Map;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.tools.Shape;
import java.nio.FloatBuffer;
import org.tensorflow.tools.buffer.DataBuffers;
import org.tensorflow.tools.ndarray.FloatNdArray;
import org.tensorflow.tools.ndarray.StdArrays;
import org.tensorflow.proto.framework.TensorInfo;

public class v2tensor {
    public static void main(String[] args) {
     try (SavedModelBundle savedModel = SavedModelBundle.load("./house_price_median_income", "serve")) {
        SignatureDef modelInfo = savedModel.metaGraphDef().getSignatureDefMap().get("serving_default");
        TensorInfo input1 = null;
        TensorInfo output1 = null;
        Map<String, TensorInfo> inputs = modelInfo.getInputsMap();
        for(Map.Entry<String, TensorInfo> input : inputs.entrySet()) {
            if (input1 == null) {
                input1 = input.getValue();
                System.out.println(input1.getName());
            }
            System.out.println(input);
        }
        Map<String, TensorInfo> outputs = modelInfo.getOutputsMap();
        for(Map.Entry<String, TensorInfo> output : outputs.entrySet()) {
            if (output1 == null) {
                output1=output.getValue();
            }
            System.out.println(output);
        }

        try (Session session = savedModel.session()) {
            Session.Runner runner = session.runner();
            FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{ { 2.1518f } } );

            try (Tensor<TFloat32> jack = TFloat32.tensorOf(matrix) ) {
                runner.feed(input1.getName(), jack);
                try ( Tensor<TFloat32> rezz = runner.fetch(output1.getName()).run().get(0).expect(TFloat32.DTYPE) ) { 
                    TFloat32 data = rezz.data();
                    data.scalars().forEachIndexed((i, s) -> {
                        System.out.println(s.getFloat());
                    }   );
                }
            }
        }
    } catch (TensorFlowException ex) {
        ex.printStackTrace();   
    }
    }
}

您需要做的是阅读SavedModelBundle元数据作为MetaGraphDef,从那里您可以检索输入和输出名称SignatureDef,就像在Python中一样。

在 TF Java 1.*(即您在示例中使用的客户端)中,原型定义无法从tensorflow工件,您需要添加依赖项org.tensorflow:proto以及反序列化结果SavedModelBundle.metaGraphDef() https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/SavedModelBundle#public-byte-metagraphdef into a MetaGraphDef proto.

在 TF Java 2.* 中(新客户端实际上只能作为来自here https://github.com/tensorflow/java/),原型立即出现,因此您只需调用此行即可检索正确的SignatureDef:

savedModel.metaGraphDef().signatureDefMap.getValue("serving_default")
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow 2.0 和 Java API 的相关文章

  • 如何使用 Java 和 Selenium WebDriver 在 C 目录中创建文件夹并需要将屏幕截图保存在该目录中?

    目前正在与硒网络驱动程序和代码Java 我有一种情况 我需要在 C 目录中创建一个文件夹 并在该文件夹中创建我通过 selenium Web 驱动程序代码拍摄的屏幕截图 它需要存储在带有时间戳的文件夹中 如果我每天按计划运行脚本 所有屏幕截
  • 在 java 类和 android 活动之间传输时音频不清晰

    我有一个android活动 它连接到一个java类并以套接字的形式向它发送数据包 该类接收声音数据包并将它们扔到 PC 扬声器 该代码运行良好 但在 PC 扬声器中播放声音时会出现持续的抖动 中断 安卓活动 public class Sen
  • 如何找到给定字符串的最长重复子串

    我是java新手 我被分配寻找字符串的最长子字符串 我在网上研究 似乎解决这个问题的好方法是实现后缀树 请告诉我如何做到这一点或者您是否有任何其他解决方案 请记住 这应该是在 Java 知识水平较低的情况下完成的 提前致谢 附 测试仪字符串
  • 在 HTTPResponse Android 中跟踪重定向

    我需要遵循 HTTPost 给我的重定向 当我发出 HTTP post 并尝试读取响应时 我得到重定向页面 html 我怎样才能解决这个问题 代码 public void parseDoc final HttpParams params n
  • INSERT..RETURNING 在 JOOQ 中不起作用

    我有一个 MariaDB 数据库 我正在尝试在表中插入一行users 它有一个生成的id我想在插入后得到它 我见过this http www jooq org doc 3 8 manual sql building sql statemen
  • 控制Android的前置LED灯

    我试图在用户按下某个按钮时在前面的 LED 上实现 1 秒红色闪烁 但我很难找到有关如何访问和使用前置 LED 的文档 教程甚至代码示例 我的意思是位于 自拍 相机和触摸屏附近的 LED 我已经看到了使用手电筒和相机类 已弃用 的示例 但我
  • 列出jshell中所有活动的方法

    是否有任何命令可以打印当前 jshell 会话中所有新创建的方法 类似的东西 list但仅适用于方法 您正在寻找命令 methods all 它会打印所有方法 包括启动 JShell 时添加的方法 以及失败 被覆盖或删除的方法 对于您声明的
  • 反射找不到对象子类型

    我试图通过使用反射来获取包中的所有类 当我使用具体类的代码 本例中为 A 时 它可以工作并打印子类信息 B 扩展 A 因此它打印 B 信息 但是当我将它与对象类一起使用时 它不起作用 我该如何修复它 这段代码的工作原理 Reflection
  • Spring @RequestMapping 带有可选参数

    我的控制器在请求映射中存在可选参数的问题 请查看下面的控制器 GetMapping produces MediaType APPLICATION JSON VALUE public ResponseEntity
  • 如何为俚语和表情符号构建正则表达式 (regex)

    我需要构建一个正则表达式来匹配俚语 即 lol lmao imo 等 和表情符号 即 P 等 我按照以下示例进行操作http www coderanch com t 497238 java java Regular Expression D
  • Java TestNG 与跨多个测试的数据驱动测试

    我正在电子商务平台中测试一系列商店 每个商店都有一系列属性 我正在考虑对其进行自动化测试 是否有可能有一个数据提供者在整个测试套件中提供数据 而不仅仅是 TestNG 中的测试 我尝试不使用 testNG xml 文件作为机制 因为这些属性
  • 在两个活动之间传输数据[重复]

    这个问题在这里已经有答案了 我正在尝试在两个不同的活动之间发送和接收数据 我在这个网站上看到了一些其他问题 但没有任何问题涉及保留头等舱的状态 例如 如果我想从 A 类发送一个整数 X 到 B 类 然后对整数 X 进行一些操作 然后将其发送
  • JRE 系统库 [WebSphere v6.1 JRE](未绑定)

    将项目导入 Eclipse 后 我的构建路径中出现以下错误 JRE System Library WebSphere v6 1 JRE unbound 谁知道怎么修它 右键单击项目 特性 gt Java 构建路径 gt 图书馆 gt JRE
  • 使用Caliper时如何指定命令行?

    我发现 Google 的微型基准测试项目 Caliper 非常有趣 但文档仍然 除了一些示例 完全不存在 我有两种不同的情况 需要影响 JVM Caliper 启动的命令行 我需要设置一些固定 最好在几个固定值之间交替 D 参数 我需要指定
  • Java Integer CompareTo() - 为什么使用比较与减法?

    我发现java lang Integer实施compareTo方法如下 public int compareTo Integer anotherInteger int thisVal this value int anotherVal an
  • 玩!框架:运行“h2-browser”可以运行,但网页不可用

    当我运行命令时activator h2 browser它会使用以下 url 打开浏览器 192 168 1 17 8082 但我得到 使用 Chrome 此网页无法使用 奇怪的是它以前确实有效 从那时起我唯一改变的是JAVA OPTS以启用
  • 使用 JMF 创建 RTP 流时出现问题

    我正处于一个项目的早期阶段 需要使用 RTP 广播DataStream创建自MediaLocation 我正在遵循一些示例代码 该代码目前在rptManager initalize localAddress 出现错误 无法打开本地数据端口
  • 如何解释tf.map_fn的结果?

    看代码 import tensorflow as tf import numpy as np elems tf ones 1 2 3 dtype tf int64 alternates tf map fn lambda x x x x el
  • 当我从 Netbeans 创建 Derby 数据库时,它存储在哪里?

    当我从 netbeans 创建 Derby 数据库时 它存储在哪里 如何将它与项目的其余部分合并到一个文件夹中 右键单击Databases gt JavaDB in the Service查看并选择Properties This will
  • 节拍匹配算法

    我最近开始尝试创建一个移动应用程序 iOS Android 它将自动击败比赛 http en wikipedia org wiki Beatmatching http en wikipedia org wiki Beatmatching 两

随机推荐

  • 为什么是 em 而不是 px?

    我听说你应该在样式表中定义尺寸和距离em而不是以像素为单位 所以问题是我为什么要使用em代替px在 CSS 中定义样式时 有一个很好的例子来说明这一点吗 说一个比另一个更好的选择是错误的 或者两者都不会在规范中给出自己的目的 甚至值得注意的
  • 为什么所有 NUL 都从我的脚本中删除?

    它看起来像 bash 还有 dash 从我的脚本中过滤掉任何 ASCII NUL printf test 000a echo test sh 1 printf test 001a echo test sh 2 printf ec 000ho
  • Heroku 混合内容 HTTPS/HTTP 问题

    我将应用程序部署到 Heroku 但在 Chrome 控制台中不断收到此错误 bundle js 11892 Mixed Content The page at https herokuapp com login was loaded ov
  • 跨线程编组 COM 接口的首选方法是什么?

    与 CoMarshalInterThreadInterfaceInStream 和 CoGetInterfaceAndReleaseStream 相比 使用 GIT 跨线程编组 COM 接口有哪些优点 缺点 是否有充分的理由选择一种方法而不
  • 解析 @username 的帖子

    我建立了一个类似 Twitter 的 replies 允许用户通过用户每日帖子相互联系 类似于 stackoverflow 以此作为指导https github com kltcalamay sample app compare origi
  • Spark DataFrame 和重命名多列 (Java)

    有没有更好的方法可以同时为给定 SparkSQL 的所有或多个列添加前缀或重命名DataFrame比多次调用dataFrame withColumnRenamed 一个例子是 如果我想检测更改 使用完整外连接 然后我就剩下两个了DataFr
  • Tensorflow 对象检测 api 验证数据大小

    我正在从对象检测 API 运行教程 并将 Oxford 数据集与 ResNet Faster RCNN 一起使用 当我通过运行 eval py 评估经过训练的模型时 Tensorboard 返回大约 0 95 的平滑精度值 我的问题是它评估
  • 我可以在 Django generic.ListView 中拥有多个列表吗?

    作为 Django 初学者 我正在研究 django 文档提供的教程 网址为https docs djangoproject com en 1 5 intro tutorial04 https docs djangoproject com
  • Android - 从收到的短信中获取日期和时间

    我正在开发一个 Android 应用程序 我需要在其中保存发件人 短信正文 日期和时间收到短信 现在我可以捕获消息正文和发件人 但我无法获取短信的日期和时间 即使我查看了 stackoverflow 中的一些帖子 但它们都没有捕获日期和时间
  • NSView 子类中的鼠标单击事件

    我有一个 NSView 子类 它使用 OpenGL 上下文进行初始化 并具有一堆鼠标事件处理 onMouseDown 等 我有一个使用它的应用程序 它有一个包含视图的主 Cocoa 窗口 并且所有鼠标事件都正常工作 但是 我现在尝试在另一个
  • 自动调整 UICollectionView 高度以适应其内容大小

    我有一个 UICollectionView 一个在集合视图中创建新单元格的按钮 我希望 UICollectionView 根据其内容大小调整其大小 当有一两个单元格时 UICollectionView 很短 如果有很多单元格 UIColle
  • 对话框服务内容中的换行符被忽略

    在剑道对话服务窗口的内容中插入换行符 rogress 中的 Kendo Angular 6 对话框 换行符将被忽略 尝试了 html 元素 br 和 n n 例如它在连续一行中显示字符文本 n n n const dialog Dialog
  • 如何向 CRM 2011 进行身份验证?

    我想建立一个简单的网站 客户可以在其中下订单和查看产品 此数据是我的 Microsoft Dynamics CRM 2011 环境 该数据是特定于客户的 因此我需要有关登录用户的信息 用户凭据存储在 CRM 2011 中 使用这些凭据 用户
  • Jython 的内存限制

    如何为我的 Jython 程序设置 JVM 内存限制 Java 的 Xmx 选项 我明白 Jython2 5引入 J 选项以便将选项发送到 JVM jython J Xmx8000m 但是 我必须与java1 6 0 23 上的 Jytho
  • android OAuth-2.0 google 使用 webview 登录获取用户信息

    我正在创建允许用户使用 facebook 或 google 帐户登录的应用程序 他们按下 登录 按钮 然后系统会要求他们使用 facebook 或 google 登录 当他们选择其中之一时 会弹出网络视图 问题是谷歌身份验证 阅读了一些文章
  • 清理带有标题的 URL 的最佳方法是什么

    清理 URL 的最佳方法是什么 我正在寻找这样的网址 什么是最好的头痛药物 我当前的代码 public string CleanURL string str str str Replace str str Replace str str R
  • 在不知道 T 类型的情况下获取 Task 的结果 [重复]

    这个问题在这里已经有答案了 我正在开发一个 C 系统 一个类有一个返回 a 的函数System Threading Tasks Task对象并具有属性System Type返回类型 当 ReturnType 为 null 时 我知道该方法返
  • Python 中的 MySQL 连接器不允许 LOAD DATA INFILE 语法

    我正在尝试将文本文件发送到 MySQL 数据库 我正在尝试使用 python 3 2 中的 mysql 连接器来执行此操作 问题与 LOAD DATA INFILE 语法有关 你可以在上面找到我的代码 我的第一个问题是有没有办法解决这个问题
  • Android:使用 Intent setResults 来回传递数据

    我正在为 Android 创建一个基于 GPS 的应用程序 有 2 个活动 Main 和 LocNames Main 显示我的地图 LocNames 是获取用户想要的源和目的地 我想在用户从菜单中选择 LocNames 时启动 LocNam
  • Tensorflow 2.0 和 Java API

    注意 我已经解决了我的问题并将代码发布在底部 我正在使用 TensorFlow 后端处理必须在 Java 中进行 我从 中取出了其中一个模型https developers google com machine learning crash