使用 feed_dict 比使用数据集 API 快 5 倍以上?

2023-11-21

我创建了一个 TFRecord 格式的数据集进行测试。每个条目包含 200 列,名为C1 - C199,每个都是一个字符串列表,和一个label列来表示标签。创建数据的代码可以在这里找到:https://github.com/codescv/tf-dist/blob/8bb3c44f55939fc66b3727a730c57887113e899c/src/gen_data.py#L25

然后我使用线性模型来训练数据。第一种方法如下所示:

dataset = tf.data.TFRecordDataset(data_file)
dataset = dataset.prefetch(buffer_size=batch_size*10)
dataset = dataset.map(parse_tfrecord, num_parallel_calls=5)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)

features, labels = dataset.make_one_shot_iterator().get_next()    
logits = tf.feature_column.linear_model(features=features, feature_columns=columns, cols_to_vars=cols_to_vars)
train_op = ...

with tf.Session() as sess:
    sess.run(train_op)

完整代码可以在这里找到:https://github.com/codescv/tf-dist/blob/master/src/lr_single.py

当我运行上面的代码时,我得到 0.85 步/秒(批量大小为 1024)。

在第二种方法中,我手动将数据集中的批次获取到 python 中,然后将它们提供给占位符,如下所示:

example = tf.placeholder(dtype=tf.string, shape=[None])
features = tf.parse_example(example, features=tf.feature_column.make_parse_example_spec(columns+[tf.feature_column.numeric_column('label', dtype=tf.float32, default_value=0)]))
labels = features.pop('label')
train_op = ...

dataset = tf.data.TFRecordDataset(data_file).repeat().batch(batch_size)
next_batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    data_batch = sess.run(next_batch)
    sess.run(train_op, feed_dict={example: data_batch})

完整代码可以在这里找到:https://github.com/codescv/tf-dist/blob/master/src/lr_single_feed.py

当我运行上面的代码时,我得到 5 步/秒。这比第一种方法快 5 倍。这是我不明白的,因为从理论上讲,由于数据批次的额外序列化/反序列化,第二个应该更慢。

Thanks!


目前(从 TensorFlow 1.9 开始)使用时存在性能问题tf.data映射和批处理具有大量特征且每个特征具有少量数据的张量。该问题有两个原因:

  1. The dataset.map(parse_tfrecord, ...)转换将执行 O(batch_size * num_columns)创建批次的小操作。相比之下,喂养tf.placeholder() to tf.parse_example()将执行 O(1) 操作来创建相同的批次。

  2. 批量多tf.SparseTensor对象使用dataset.batch()比直接创建相同的要慢得多tf.SparseTensor作为输出tf.parse_example().

对这两个问题的改进正在进行中,并且应该会在 TensorFlow 的未来版本中提供。同时,您可以提高性能tf.data基于管道通过切换顺序dataset.map() and dataset.batch()并重写dataset.map()处理字符串向量,例如基于喂食的版本:

dataset = tf.data.TFRecordDataset(data_file)
dataset = dataset.prefetch(buffer_size=batch_size*10)
dataset = dataset.repeat(num_epochs)

# Batch first to create a vector of strings as input to the map(). 
dataset = dataset.batch(batch_size)

def parse_tfrecord_batch(record_batch):
  features = tf.parse_example(
      record_batch,
      features=tf.feature_column.make_parse_example_spec(
          columns + [
              tf.feature_column.numeric_column(
                  'label', dtype=tf.float32, default_value=0)]))
  labels = features.pop('label')
  return features, labels

# NOTE: Parallelism might not be as useful, because the individual map function now does
# more work per invocation, but you might want to experiment with this.
dataset = dataset.map(parse_tfrecord_batch)

# Add a prefetch at the end to pipeline execution.
dataset = dataset.prefetch(1)

features, labels = dataset.make_one_shot_iterator().get_next()    
# ...

编辑 (2018/6/18): 回答一下评论里的问题:

  1. Why is dataset.map(parse_tfrecord, ...) O(batch_size * num_columns),而不是 O(batch_size)?如果解析需要枚举列,为什么 parse_example 不采用 O(num_columns)?

当您将 TensorFlow 代码包装在Dataset.map()(或其他函数转换)每个输出的恒定数量的额外操作被添加到函数的“返回”值中,并且(在tf.SparseTensor值)将它们“转换”为标准格式。当你直接传递输出时tf.parse_example()对于模型的输入,不会添加这些操作。虽然它们都是非常小的操作,但执行如此多的操作可能会成为瓶颈。 (从技术上讲,解析does take O(batch_size * num_columns) time,但解析中涉及的常量比执行操作要小得多。)

  1. 为什么要在管道末尾添加预取?

当您对性能感兴趣时,这几乎总是最好的选择,并且它应该提高管道的整体性能。有关最佳实践的更多信息,请参阅性能指南tf.data.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 feed_dict 比使用数据集 API 快 5 倍以上? 的相关文章

随机推荐

  • 如何获取 pyspark 中 groupBy 之后每个计数的总数百分比?

    给定以下数据框 import findspark findspark init from pyspark sql import SparkSession spark SparkSession builder master local app
  • 不再能够创建 bacpac:SQL70015:SQL Azure 上不支持已弃用的功能“字符串文字作为列别名”

    今天我们遇到了一个严重错误 我们无法再为实时 Azure 生产数据库创建 bacpac 文件 到目前为止一切正常 突然我们开始遇到以下错误 服务操作期间遇到错误 无法从指定数据库提取包 错误 SQL70015 SQL Azure 不支持已弃
  • 为什么 onResume() 似乎被调用了两次?

    我在这里有我的活动课程 public class CameraActivity extends Activity private Camera mCamera private CameraPreview mPreview public vo
  • 如何在 mayavi2 中缩放 x 轴和 y 轴?

    我想使用 mayavi mlab surf 用 mayavi2 绘制 3 d 绘图 该函数有一个名为 warp scale 的参数 可用于缩放 z 轴 我正在寻找类似的东西 但适用于 x 和 y 轴 我可以通过将 x 和 y 数组相乘 然后
  • python 中数组的导数?

    目前我有两个 numpy 数组 x and y大小相同 我想编写一个函数 可能调用 numpy scipy 函数 如果存在的话 def derivative x y n 1 something return result where res
  • m 的静态声明位于非静态声明之后

    我正在尝试一个小例子来了解静态外部变量及其用途 静态变量是局部范围的 外部变量是全局范围的 静态5 c include
  • 如何在 Chrome 扩展浏览器操作中显示 Google reCAPTCHA v2?

    我正在构建一个 Chrome 扩展程序 它与我希望使用 Google recatcha 保护的 API 进行交互 因为我打算让它在 Chrome 扩展程序之外使用 API 端正在工作 正确验证了 Google 的 recapcha 响应 但
  • SerialVersionUID 是如何计算的

    当我在 Eclipse 中创建 Java 类时 它实现了Serializable界面 我收到警告 可序列化类 ABCD 未声明静态final long 类型的serialVersionUID 字段 因此 当我单击警告时 我会在 Eclips
  • 从具有自定义字段的表单创建 mailto

    我有一个包含 3 个字段 姓名 电子邮件和消息 的 HTML 表单 我想使用这 3 个字段创建自定义 mailto 但我不想创建如下所示的固定内容 a href Send a mail a 这可能吗 如果不是 我是否有其他方法来制作简单的处
  • 使用 npm 安装 bcrypt 时出错

    我无法安装bcrypt using npm在我的机器上 因为我遇到以下错误 我一直在解决这个问题 但运气不佳 您能否建议任何步骤来诊断或解决问题 以便我可以运行npm install bcrypt成功地 Someones Macbook n
  • 如何以编程方式(合法地)获取街道地址的经度和纬度

    据说 可以从谷歌地图或某些此类服务中获取此信息 仅美国地址是不够的 您正在寻找的术语是地理编码 是的 谷歌确实提供了这项服务 新的V3 API http code google com apis maps documentation geo
  • 如何追踪这个? AttributeError:“NoneType”对象在 makemigrations 期间没有属性“is_relation”

    自昨天以来我第二次遇到令人困惑的错误 上次我只是扁平化了整个迁移 但我从未真正找到导致问题的原因 所以当我尝试为我的 python 项目进行迁移时就会出现这种情况 我应该在哪里寻找错误 我觉得这实际上与迁移无关 而是与views py或mo
  • “核心语言”是什么意思?

    在表中关于这一页从 GCC 文档来看 其中一项 大约在表格的中间 仅被列为 核心语言 这意味着什么 语言的哪些部分不会被包括在内 标准库是该语言的一部分 为了表达仅与语法规则 语义规则等相关但与库无关的语言子集 人们使用术语核心语言 例如
  • 如何从 Android 手机获取时区?

    我想在单击按钮时从 Android 手机获取时区 您是否尝试过使用TimeZone getDefault 大多数应用程序都会使用时区 getDefault 它返回一个基于时区的 程序运行所在的时区 Ref http developer an
  • Django仅在生产环境中使用私有S3存储

    我已将 django REST API 设置为在调试模式下使用本地存储 在生产环境中使用 S3 存储 这对于公共文件很有效 因为我覆盖了DEFAULT FILE STORAGE像这样 if IS DEBUG DEFAULT FILE STO
  • 接受多个 Id 值的 T-SQL 存储过程

    有没有一种优雅的方法来处理将 id 列表作为参数传递给存储过程 例如 我希望我的存储过程返回部门 1 2 5 7 20 过去 我传递了一个逗号分隔的 id 列表 如下面的代码 但感觉这样做真的很脏 我认为 SQL Server 2005 是
  • .NET 中的 C# 类何时调用析构函数?

    比如说 我有自己的 C 类 定义如下 public class MyClass public MyClass Do the work MyClass Destructor 然后我从 ASP NET 项目创建类的实例 如下所示 if true
  • Google Chrome .dev 无法通过 http 工作 [重复]

    这个问题在这里已经有答案了 自上次更新以来谷歌浏览器 63 0 3239 84 the dev我的本地开发计算机的域不再工作 因为浏览器强制 URL 通过 https 并且我的本地计算机上没有 sicure 证书 有没有办法让它与 dev
  • 64 位 iOS 设备上的 asm("trap")

    在我自己开发的断言宏中 我一直在 iOS 设备上使用 asm trap 或在 iOS 模拟器上使用 asm int3 来中断调试器 然而 在设备的 64 位版本中 我得到了陷阱指令的 无法识别的指令助记符 有与arm64相当的吗 像 bui
  • 使用 feed_dict 比使用数据集 API 快 5 倍以上?

    我创建了一个 TFRecord 格式的数据集进行测试 每个条目包含 200 列 名为C1 C199 每个都是一个字符串列表 和一个label列来表示标签 创建数据的代码可以在这里找到 https github com codescv tf