使用 TensorFlow Benchmark 对 Keras 模型进行基准测试

2023-11-30

我正在尝试使用 TensorFlow 后端对 Keras 模型构建的推理阶段的性能进行基准测试。我当时想的是张量流基准测试工具是正确的方法。

我已经成功地在桌面上构建并运行了示例tensorflow_inception_graph.pb一切似乎都运行良好。

我似乎无法弄清楚如何将 Keras 模型保存为正确的模型.pb模型。我可以从 Keras 模型获取 TensorFlow Graph,如下所示:

import keras.backend as K
K.set_learning_phase(0)

trained_model = function_that_returns_compiled_model()
sess = K.get_session()
sess.graph # This works

# Get the input tensor name for TF Benchmark
trained_model.input
> <tf.Tensor 'input_1:0' shape=(?, 360, 480, 3) dtype=float32>

# Get the output tensor name for TF Benchmark
trained_model.output
> <tf.Tensor 'reshape_2/Reshape:0' shape=(?, 360, 480, 12) dtype=float32>

我现在一直在尝试以几种不同的方式保存模型。

import tensorflow as tf
from tensorflow.contrib.session_bundle import exporter

model = trained_model
export_path = "path/to/folder"  # where to save the exported graph
export_version = 1  # version number (integer)

saver = tf.train.Saver(sharded=True)
model_exporter = exporter.Exporter(saver)
signature = exporter.classification_signature(input_tensor=model.input, scores_tensor=model.output)
model_exporter.init(sess.graph.as_graph_def(), default_graph_signature=signature)
model_exporter.export(export_path, tf.constant(export_version), sess)

这会生成一个文件夹,其中包含一些我不知道如何处理的文件。

我现在将运行基准测试工具,如下所示

bazel-bin/tensorflow/tools/benchmark/benchmark_model \
  --graph=tensorflow/tools/benchmark/what_file.pb \
  --input_layer="input_1:0" \
  --input_layer_shape="1,360,480,3" \
  --input_layer_type="float" \
  --output_layer="reshape_2/Reshape:0"

但无论我尝试使用哪个文件作为what_file.pb我得到了一个Error during inference: Invalid argument: Session was not created with a graph before Run()!


所以我让它发挥作用。只需将张量流图中的所有变量转换为常量,然后保存图定义。

这是一个小例子:

import tensorflow as tf

from keras import backend as K
from tensorflow.python.framework import graph_util

K.set_learning_phase(0)
model = function_that_returns_your_keras_model()
sess = K.get_session()

output_node_name = "my_output_node" # Name of your output node

with sess as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    graph_def = sess.graph.as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(
                                                                 sess,
                                                                 sess.graph.as_graph_def(),
                                                                 output_node_name.split(","))
    tf.train.write_graph(output_graph_def,
                         logdir="my_dir",
                         name="my_model.pb",
                         as_text=False)

现在只需调用 TensorFlow Benchmark 工具即可my_model.pb如图所示。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 TensorFlow Benchmark 对 Keras 模型进行基准测试 的相关文章

随机推荐

  • 我应该设计一个主键为 varchar 还是 int 的表?

    我知道这是主观的 但我想了解人们的意见 并希望在设计 sql server 表结构时可以应用一些最佳实践 我个人认为 在固定 最大 长度的 varchar 上键入表是不行的 因为这意味着必须在使用它作为外键的任何其他表上传播相同的固定长度
  • 如何读取/流式传输文件而不将整个文件加载到内存中?

    如何读取任意文件并 逐个 处理它 意味着逐字节或其他一些可以提供最佳读取性能的块大小 而不将整个文件加载到内存中 处理的一个示例是生成文件的 MD5 哈希值 尽管答案可以适用于任何操作 我想拥有或编写这个 但如果我可以获得现有的代码 那就太
  • PDO 使用键作为列名插入数组

    我正在使用 PDO 将 PHP 数组的 POST 内容插入到表中 我正在查看以下代码行 我有一个 必须有更好的方法来做到这一点 的时刻 如果键名与表中的列名匹配 是否有更简单的方法来插入所有键名 代码例如 statement db gt p
  • 下划线:基于多个属性的sortBy()

    我正在尝试根据多个属性对包含对象的数组进行排序 即 如果两个对象之间的第一个属性相同 则应使用第二个属性来比较这两个对象 例如 考虑以下数组 var patients name John roomNumber 1 bedNumber 1 n
  • 与 AVX/AVX2 一起使用的最低 OS X 版本是什么?

    我有一个图像绘制例程 为 SSE SSE2 SSE3 SSE4 1 SSE4 2 AVX 和 AVX2 编译多次 我的程序通过检查 CPUID 标志来动态调度这些二进制变体之一 在 Windows 上 我检查 Windows 版本 如果操作
  • 在 C 中设置位

    我正在尝试执行以下操作 写一个函数setbits x p n y 返回x with n开始于的位 位置p设置到最右边n的位y 留下其他位 不变 我这样尝试但没有得到正确的答案 谁能告诉我哪里错了 unsigned setbits unsig
  • Android 键盘隐藏 EditText

    当我尝试在屏幕底部的 EditText 中写入内容时 软键盘会隐藏 EditText 我该如何解决这个问题 下面是我的 xml 代码 我在片段中使用它
  • 如何检索 Android 中可用/已安装字体的列表?

    在Java中我会做类似的事情 java awt GraphicsEnvironment ge java awt GraphicsEnvironment getLocalGraphicsEnvironment Font fonts ge ge
  • 命令行参数太多 Terraform 计划

    我是地形新手 我正在尝试通过天蓝色管道创建一个简单的存储帐户 但是当我运行管道时 我收到错误 命令行参数太多 我很震惊 我不知道我做错了什么 有人可以帮忙吗 这是我的计划脚本 script terraform plan out plan t
  • 如何使用占位符将列名值作为 SQL 参数传递

    如何使用参数占位符将列名值作为 SQL 参数传递 目标是让这个工作 var sql SELECT FROM Condos WHERE 0 LIKE 1 var sqlData db Query sql choice searchString
  • 如何从java应用程序创建Windows服务

    我刚刚继承了一个java应用程序 需要将其作为服务安装在XP和vista上 自从我以任何形式使用 Windows 以来 已经有大约 8 年了 我从来没有创建过服务 更不用说像 Java 应用程序这样的东西了 我有一个应用程序的 jar 和一
  • Android 连接至已配对的蓝牙耳机

    我想模拟通过 设置 gt 无线 gt 蓝牙 的操作 并以编程方式连接配对的蓝牙耳机 我在 Stackoverflow 和 Google 上进行了一些搜索 两者都表明在 API 级别 11 之前没有可用的解决方案 但是 我有兴趣通过查看 An
  • 使用 python 和 pandas 按季节对数据进行分组

    我想使用 Pandas 和 Python 迭代我的 csv 文件 并按季节对数据进行分组 计算一年中每个季节的平均值 目前 季度脚本为一月至三月 四月至六月等 我希望季节与月份相关联 11 冬季 12 冬季 1 冬季 2 春季 3 春天 4
  • 解析复杂的肥皂响应

    我正在 android 中构建我的第一个应用程序 该应用程序使用 wcf 服务 我正在使用 ksoap2 来解析响应 响应实际上是 C 中定义的对象数组 我这样做了 这非常有帮助guide现在我的问题是我需要使用一个 wcf 服务 它再次返
  • 如何在 OPENGL 中旋转或平移单个对象实例?

    假设我有一个有四个立方体的场景 我该如何说在 OpenGL 中仅旋转 平移其中两个立方体而不使用 glrotatef 和 f gltranslate 更改其他立方体 我不想定义我自己的齐次坐标 像往常一样绘制前两个立方体 推入视图模型矩阵
  • iPhone 上的点击延迟和抑制输入焦点

    iPhone 上的 webkit 浏览器在用户进行触摸和 javascript 获取单击事件之间有 300 毫秒的延迟 发生这种情况是因为浏览器需要检查用户是否进行了双击 我的应用程序不允许缩放 因此双击对我来说毫无用处 有不少人有提出的解
  • 如何安全地调用 vsnprintf() ?

    我正在将一些非常古老 gt 10 年 的 C 代码移植到现代 Linux 我在自定义编写的 vsnprintf 包装器中遇到分段错误 显然它的任务是检测重复的输出字符串并实习它们 char strVPrintf const String f
  • Oracle中的查询通过子查询进行选择

    我的 Oracle 数据库中有下表 CREATE TABLE test flight NUMBER 4 date DATE action VARCHAR2 50 CONSTRAINT pk PRIMARY KEY flight date 以
  • 为什么 Glimpse 仍在运行?

    我瞥见了 defaultRuntimePolicy Off 但它仍然显示这样的错误 Unable to define EFProfiledDbProviderServices class of type GlimpseDbProviderS
  • 使用 TensorFlow Benchmark 对 Keras 模型进行基准测试

    我正在尝试使用 TensorFlow 后端对 Keras 模型构建的推理阶段的性能进行基准测试 我当时想的是张量流基准测试工具是正确的方法 我已经成功地在桌面上构建并运行了示例tensorflow inception graph pb一切似