Tensorflow tf.data.Dataset.cache似乎没有达到预期的效果

2024-03-13

我正在尝试按照以下方法提高我的模型训练性能使用 tf.data API 获得更好的性能 https://www.tensorflow.org/guide/data_performance指导方针。然而,我观察到使用的性能.cache()如果与没有的相同设置相比,几乎相同甚至更糟.cache().

datafile_list = load_my_files()
RAW_BYTES = 403*4
BATCH_SIZE = 32

raw_dataset = tf.data.FixedLengthRecordDataset(filenames=datafile_list, record_bytes=RAW_BYTES, num_parallel_reads=10, buffer_size=1024*RAW_BYTES)
raw_dataset = raw_dataset.map(tf.autograph.experimental.do_not_convert(decode_and_prepare),
    num_parallel_calls=tf.data.AUTOTUNE)
raw_dataset = raw_dataset.cache()
raw_dataset = raw_dataset.shuffle(buffer_size=1024)
raw_dataset = raw_dataset.batch(BATCH_SIZE)
raw_dataset = raw_dataset.prefetch(tf.data.AUTOTUNE)

数据在datafile_list保留 9.92GB,相当适合系统可用的总物理 RAM (100GB)。系统交换已禁用。

通过使用数据集训练模型:

model = build_model()
model.fit(raw_dataset, epochs=5, verbose=2)

结果是:

Epoch 1/5
206247/206247 - 126s - loss: 0.0043 - mae: 0.0494 - mse: 0.0043
Epoch 2/5
206247/206247 - 125s - loss: 0.0029 - mae: 0.0415 - mse: 0.0029
Epoch 3/5
206247/206247 - 129s - loss: 0.0027 - mae: 0.0397 - mse: 0.0027
Epoch 4/5
206247/206247 - 125s - loss: 0.0025 - mae: 0.0386 - mse: 0.0025
Epoch 5/5
206247/206247 - 125s - loss: 0.0024 - mae: 0.0379 - mse: 0.0024

这个结果令人沮丧。由docs https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache:

第一次迭代数据集时,其元素将被缓存在指定文件或内存中。后续迭代将使用缓存的数据。

并从本指南 https://www.tensorflow.org/datasets/performances#caching_the_dataset:

迭代此数据集时,由于缓存,第二次迭代将比第一次迭代快得多。

然而,所有历元所花费的时间几乎相同。此外,在训练过程中,CPU 和 GPU 的使用率都非常低(见下图)。

通过注释掉该行raw_dataset = raw_dataset.cache()结果没有显示任何显着差异:

Epoch 1/5
206067/206067 - 129s - loss: 0.0042 - mae: 0.0492 - mse: 0.0042
Epoch 2/5
206067/206067 - 127s - loss: 0.0028 - mae: 0.0412 - mse: 0.0028
Epoch 3/5
206067/206067 - 134s - loss: 0.0026 - mae: 0.0393 - mse: 0.0026
Epoch 4/5
206067/206067 - 127s - loss: 0.0024 - mae: 0.0383 - mse: 0.0024
Epoch 5/5
206067/206067 - 126s - loss: 0.0023 - mae: 0.0376 - mse: 0.0023

正如文档中指出的,我的期望是使用缓存会导致训练时间更快。我想知道我做错了什么。

附件

使用缓存进行训练期间的 GPU 使用情况:

训练期间没有缓存的 GPU 使用情况:

使用缓存进行训练期间的系统统计信息(内存、CPU 等):

训练期间没有缓存的系统统计信息(内存、CPU 等):


只是使用 Google Colab 进行的一个小观察。根据docs https://www.tensorflow.org/api_docs/python/tf/data/Dataset?version=nightly#cache:

注意:为了最终确定缓存,必须完整迭代输入数据集。否则,后续迭代将不会使用缓存数据。

And

注意:缓存每次都会产生完全相同的元素 迭代数据集。如果您希望随机化迭代 order,确保在调用cache之后调用shuffle。

我确实注意到在事先使用缓存和迭代数据集时存在一些差异。这是一个例子。

准备数据:

import random
import struct
import tensorflow as tf
import numpy as np

RAW_N = 2 + 20*20 + 1

bytess = random.sample(range(1, 5000), RAW_N*4)
with open('mydata.bin', 'wb') as f:
  f.write(struct.pack('1612i', *bytess))
def decode_and_prepare(register):
  register = tf.io.decode_raw(register, out_type=tf.float32)
  inputs = register[2:402]
  label = tf.random.uniform(()) + register[402:]
  return inputs, label

raw_dataset = tf.data.FixedLengthRecordDataset(filenames=['/content/mydata.bin']*7000, record_bytes=RAW_N*4)
raw_dataset = raw_dataset.map(decode_and_prepare)

火车模型without预先缓存和迭代:

total_data_entries = len(list(raw_dataset.map(lambda x, y: (x, y))))
train_ds = raw_dataset.shuffle(buffer_size=total_data_entries).batch(32).prefetch(tf.data.AUTOTUNE)
inputs = tf.keras.layers.Input((400,))
x = tf.keras.layers.Dense(200, activation='relu', kernel_initializer='normal')(inputs)
x = tf.keras.layers.Dense(100, activation='relu', kernel_initializer='normal')(x)
outputs = tf.keras.layers.Dense(1, kernel_initializer='normal')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')
model.fit(train_ds, epochs=5)
Epoch 1/5
875/875 [==============================] - 4s 3ms/step - loss: 0.1425
Epoch 2/5
875/875 [==============================] - 4s 3ms/step - loss: 0.0841
Epoch 3/5
875/875 [==============================] - 4s 3ms/step - loss: 0.0840
Epoch 4/5
875/875 [==============================] - 4s 3ms/step - loss: 0.0840
Epoch 5/5
875/875 [==============================] - 4s 3ms/step - loss: 0.0840
<keras.callbacks.History at 0x7fc41be037d0>

训练模型with缓存但是no迭代:

total_data_entries = len(list(raw_dataset.map(lambda x, y: (x, y))))
train_ds = raw_dataset.shuffle(buffer_size=total_data_entries).cache().batch(32).prefetch(tf.data.AUTOTUNE)
inputs = tf.keras.layers.Input((400,))
x = tf.keras.layers.Dense(200, activation='relu', kernel_initializer='normal')(inputs)
x = tf.keras.layers.Dense(100, activation='relu', kernel_initializer='normal')(x)
outputs = tf.keras.layers.Dense(1, kernel_initializer='normal')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')
model.fit(train_ds, epochs=5)
Epoch 1/5
875/875 [==============================] - 4s 2ms/step - loss: 0.1428
Epoch 2/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0841
Epoch 3/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0840
Epoch 4/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0840
Epoch 5/5
875/875 [==============================] - 2s 3ms/step - loss: 0.0840
<keras.callbacks.History at 0x7fc41fa87810>

训练模型with缓存和迭代:

total_data_entries = len(list(raw_dataset.map(lambda x, y: (x, y))))
train_ds = raw_dataset.shuffle(buffer_size=total_data_entries).cache().batch(32).prefetch(tf.data.AUTOTUNE)
_ = list(train_ds.as_numpy_iterator()) # iterate dataset beforehand
inputs = tf.keras.layers.Input((400,))
x = tf.keras.layers.Dense(200, activation='relu', kernel_initializer='normal')(inputs)
x = tf.keras.layers.Dense(100, activation='relu', kernel_initializer='normal')(x)
outputs = tf.keras.layers.Dense(1, kernel_initializer='normal')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')
model.fit(train_ds, epochs=5)
Epoch 1/5
875/875 [==============================] - 3s 3ms/step - loss: 0.1427
Epoch 2/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0841
Epoch 3/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0840
Epoch 4/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0840
Epoch 5/5
875/875 [==============================] - 2s 2ms/step - loss: 0.0840
<keras.callbacks.History at 0x7fc41ac9c850>

结论:数据集的缓存和先前迭代似乎对训练有影响,但在本例中仅使用了 7000 个文件。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow tf.data.Dataset.cache似乎没有达到预期的效果 的相关文章

  • 类型错误:object.__new__(int) 不安全,请使用 int.__new__()

    阅读本文时 Python 中的元类是什么 https stackoverflow com questions 100003 what is a metaclass in python 我正在学习使用 new 使用以下代码片段 class a
  • 出现异常时进行截图

    嘿 有没有一种方法可以在异常 任何异常 时捕获屏幕截图 我的 失败 解决方案位于BaseTestCase unittest TestCase子类 class BaseTestCase unittest TestCase classmetho
  • 如何选择单选按钮?

    我在用mechanize我正在尝试从单选按钮列表中选择一个按钮 该列表有 5 项 如何选择第一项 文档没有帮助我 gt gt gt br form
  • OpenPyXL - 如何查询单元格边框?

    python 和 openpyxl 都是新的 编写一个 py 脚本来遍历大量 Excel 工作簿 工作表 并且需要找到由边框格式标识的某些单元格 我在网上看到几个关于如何设置单元格边框的示例 但我需要阅读它们 具体来说 当表内的数据不一致但
  • 为什么 Python 中的无分支函数和内置函数速度较慢?

    我发现了 2 个无分支函数 它们可以在 python 中查找两个数字的最大值 并将它们与 if 语句和内置 max 函数进行比较 我认为无分支或内置函数将是最快的 但最快的是 if 语句函数 有人知道这是为什么吗 以下是功能 If 语句 2
  • 合并一个对(元组)列表?

    从链接对的列表中 我想将这些对组合成公共 ID 组 这样我就可以将 group ids 写回数据库 例如 UPDATE table SET group n WHERE id IN Example 1 2 3 4 1 5 6 3 7 8 be
  • Python 中字典的合并层次结构

    我有两本词典 而我想做的事情有点奇怪 基本上 我想合并它们 这很简单 但它们是字典的层次结构 我想以这样的方式合并它们 如果字典中的项目本身就是字典并且存在于两者中 我也想合并这些字典 如果它不是字典 我希望第二个字典中的值覆盖第一个字典中
  • 字典键中的通配符

    假设我有一本字典 rank dict V 1 A 2 V 3 A 4 正如您所看到的 我在一个 V 的末尾添加了一个 虽然 3 可能只是 V 的值 但我想要 V1 V2 V2234432 等的另一个密钥 我想检查它 checker V30
  • python下安装xgboost 32位msys失败

    尝试安装 xgboost 失败 Windows 和企业版版本为 Anaconda 2 1 0 64 位 我该如何继续 我一直在使用 R 似乎从 RStudio 在 R 中安装新包相当容易 但在间谍程序中则不然 因为我需要进入命令窗口来执行此
  • 如何在Python中使用内联正则表达式修饰符[重复]

    这个问题在这里已经有答案了 我有一个正则表达式 n DOCUMENTATION n n n 2 s 女巫我正在尝试处理这样的一些文件 usr bin python coding utf 8
  • Pythonwinsound,ASYNC 标志不起作用?

    我正在使用 python 3 5 我试图在继续执行脚本的同时播放声音 根据https docs python org 3 5 library winsound html https docs python org 3 5 library w
  • Keras CNN 回归模型损失低,准确度为 0

    我在 keras 中遇到这个 NN 回归模型的问题 我正在研究一个汽车数据集 以根据 13 个维度预测价格 简而言之 我已将其读取为 pandas 数据帧 将数值转换为浮点数 缩放值 然后对分类值使用 one hot 编码 这创建了很多新列
  • 有没有比 ` except: pass` 更简洁的替代方案?

    我有一个函数 可以按偏好顺序返回多个组的随机成员 事情是这样的 def get random foo or bar I d rather have a foo than a bar if there are foos return get
  • 如何在 django-rest-framework 查询集响应中添加注释数据?

    我正在为查询集中的每个项目生成聚合 def get queryset self from django db models import Count queryset Book objects annotate Count authors
  • 在 Django 中使用 path() 找不到 404

    我刚刚查看 django 并尝试通过视图列出书籍id作为 URL 的参数books urls py 但出现 404 页面未找到错误 当我在浏览器中输入此网址时 我没有发现网址有什么问题 http 192 168 0 106 8000 boo
  • Django model.foreignKey 并返回 self.text 错误

    所以我正在 Django 中处理 model py 但遇到了 2 个 pylint 错误 我不明白为什么 这是 pylint 的问题还是我在代码中做错了什么 E1120 No value for argument on delete in
  • 在 pandas DataFrame 中使用比较列表的问题

    我在 pandas 中有一个 DataFrame 其列类型之一是 int 上的列表 如下所示 df pandas DataFrame 1 2 3 4 5 6 7 8 9 10 columns a b c d gt gt gt df a b
  • 通过 Selenium 和 python 切换到 iframe

    我如何在硒中切换到这个 iframe 只知道 您可以使用 XPath 来定位 iframe driver find element by xpath iframe name Dialogue Window Then switch to th
  • 在绘图中的线间隙之间添加注释

    I have a graph like this 而不是在上面的日子symbol 我想知道是否有办法可以在行之间添加此注释 从一个点到另一个点 如果以防万一 这可能是重复的 我深表歉意 This is my expected output
  • 我可以在某些网格中打印带有颜色的 pandas 数据框吗?

    我有一个 pandas DataFrame 我想突出显示一些数据 例如 In 1 import pandas as pd In 2 import numpy as np In 3 df pd DataFrame np reshape ran

随机推荐

  • 如何将具有前端 SPA 的 Azure CDN 和具有 .Net Core WebApi 的 Azure WebApp 配置到同一自定义域?

    我想拥有https example com https example com作为我设置的 Azure CDN 的自定义域 并且https example com api https example com api作为其余 api 端点来捕
  • 组对组划分

    数据集 date bal 1 31 2013 10 1 31 2013 11 1 31 2013 12 1 31 2013 13 1 31 2013 14 2 28 2013 20 2 28 2013 30 2 28 2013 40 2 2
  • 异步 P/Invoke 调用

    我正在为机器人控制器开发一个包装库 该库主要依赖于 P Invoke 调用 然而 机器人的许多功能 例如归位或移动 需要相当长的时间 并且在运行时会进行线程锁定 所以我想知道如何以异步方式包装功能 这样调用就不会阻塞我的 UI 线程 到目前
  • 如何链接到 rustdoc 中的其他 fns/structs/enums/traits?

    我正在构建一个 Rust 库 并想对其进行一些改进 在 rustdoc 中 我有时想link文档中库的其他部分 例如fns traits or structs 官方语法是什么 As of 铁锈 1 48 https github com r
  • Django 反序列化错误安装 Fixture 时出现问题

    Traceback most recent call last File Users sparshkedia Desktop task venv lib python3 6 site packages django core seriali
  • 如何对这个哈希数组进行分组?

    我有这个哈希数组 name Ben age 18 name David age 19 name Sam age 18 我需要将它们分组age 所以他们最终会变成这样 18 name Ben age 18 name Sam age 18 19
  • NestJs中带有多个参数的@Get DTO

    我正在尝试在 NestJS 中创建一个可通过 GET HTTP 请求访问的控制器操作 该请求接收两个参数 但由于某种原因它们未定义 如何修复它 Get login login Param params LoginUserDto consol
  • 在 Tumblr 上每 3 个帖子添加内容

    我想知道是否有办法在每个页面上的第 3 篇文章之后放置内容 以便我可以渲染一些内容 我在 tumblr 主题 API 上没有找到任何内容 带有 API 的特定帖子 如果您使用 API 来收集 附加帖子 则需要您完成此操作 一个简单的循环 计
  • “我们很抱歉,但有些不对劲。”部署到 Heroku 后

    我制作了一个小型应用程序 用户可以在其中登录 退出 创建等等 我使用 mySQL 作为数据库 并且在本地环境中一切正常 但是当我将其部署到heroku并迁移数据库等之后 heroku版本不起作用 当我追踪日志时我得到了这个 2011 10
  • 仅对单个类禁用 Linq to SQL 类中的自动复数化

    我有一个带有不规则复数的表名 复数与单数相同 有没有办法禁用该单个表的自动复数 Account DB Accounts 同时保留其他表的功能 您需要禁用 LINQ to SQL 设计器的复数表名称 为此 请导航至 工具 gt 选项 gt 数
  • 使用本地 WSDL 文件生成 Metro 客户端

    我之前使用 wsimport 生成了 Metro 客户端 但在这种情况下 WSDL 是通过 https 访问的 我的命令看起来像这样 wsimport https service net services Service wsdl d C
  • Ubuntu:按 Super+L 时不要锁定屏幕 [关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 Whenever I press Super L or Win L on my Ubuntu 14 04 Desktop the scre
  • 按值字母顺序对 Javascript 对象进行排序

    我有一个 JS 对象如下 var obj 00 11 22 33 44 55 AddressB 66 77 88 99 AA BB AddressA 55 44 33 22 11 00 AddressC AA BB CC DD EE FF
  • Apache Kafka 主题名称限制有哪些?

    我刚刚尝试创建一个 Kafka 主题 user created 并在 Kafka 日志中看到此错误 Invalid character in value part of property 我用谷歌搜索发现 在邮件列表中 人们正在谈论弃用 a
  • React Native 后台计时器永远不会停止

    我正在构建一个应用程序 它有一个计时器 可以在计时器处于活动状态时请求地理位置 对于我正在使用的计时器反应本机背景计时器 https github com ocetnik react native background timer 这是可行
  • 调用 sp_rename 时使用变量

    我尝试制作一个存储过程 它将 删除主键 重命名设置主键的列名 创建新的主键 我正在努力解决第 2 点 我正在尝试将列重命名为sp rename将参数传递给存储过程 如下所示 EXEC sp rename SCHEMA TABLE ID Id
  • 为什么我运行 python manage.py runserver 时有两个进程

    wenzhixue 80384 0 4 1 1 2464788 22584 s001 S 10 37AM 0 01 06 usr bin python manage py runserver 0 0 0 0 8000 wenzhixue 8
  • 如何处理大量浮点数据?

    我们有一个二进制文件 其中包含大量float数据 约80MB 我们需要在 Java 应用程序中处理它 数据来自医疗扫描仪 一个文件包含来自一个文件的数据Rotation One Rotation包含 960Views One View包含
  • 为构建器配置 lombok

    我想避免多个构造函数 所以我想使用建造者设计模式 https en wikipedia org wiki Builder pattern 通过使用lombok https projectlombok org setup maven图书馆 它
  • Tensorflow tf.data.Dataset.cache似乎没有达到预期的效果

    我正在尝试按照以下方法提高我的模型训练性能使用 tf data API 获得更好的性能 https www tensorflow org guide data performance指导方针 然而 我观察到使用的性能 cache 如果与没有