与 Tensorflow 中的常规 LSTMCell 相比,使用 CudnnLSTM 训练时的结果不同

2023-11-24

我正在 Python 中使用 Tensorflow 训练 LSTM 网络,并希望切换到 tf.contrib.cudnn_rnn.CudnnLSTM 以加快训练速度。我所做的被替换

cells = tf.nn.rnn_cell.LSTMCell(self.num_hidden) 
initial_state = cells.zero_state(self.batch_size, tf.float32)
rnn_outputs, _ = tf.nn.dynamic_rnn(cells, my_inputs, initial_state = initial_state)

with

lstm = tf.contrib.cudnn_rnn.CudnnLSTM(1, self.num_hidden)
rnn_outputs, _ = lstm(my_inputs)

我的训练速度显着提升(超过 10 倍),但同时我的性能指标却下降了。使用 LSTMCell 时,二元分类的 AUC 为 0.741,使用 CudnnLSTM 时,二元分类的 AUC 为 0.705。我想知道我是否做错了什么,或者这两者之间的实现存在差异,这就是如何在继续使用 CudnnLSTM 的同时恢复性能的情况。

训练数据集有 15,337 个不同长度的序列(最多几百个元素),这些序列用零填充,以便在每个批次中具有相同的长度。所有代码都是相同的,包括 TF 数据集 API 管道和所有评估指标。我运行了每个版本几次,并且在所有情况下它都收敛于这些值。

此外,我几乎没有可以插入完全相同模型的数据集,并且所有这些数据集都存在问题。

In the cudnn_rnn 的张量流代码我找到一句话说:

Cudnn LSTM 和 GRU 在数学上不同于它们的 tf 同行。

但没有解释这些差异到底是什么......


它似乎tf.contrib.cudnn_rnn.CudnnLSTM是时间主要的,所以应该提供形状的顺序(seq_len, batch_size, embedding_size)代替(batch_size, seq_len, embedding_size),所以你必须转置它(我认为,当涉及到混乱的 Tensorflow 文档时无法确定,但你可能想测试一下。如果你想检查它,请参阅下面的链接)。

有关该主题的更多信息here(其中有另一个链接指向数学差异),除了一件事似乎是错误的:不仅 GRU 是时间主要的,LSTM 也是(如这个问题).

我会建议against using tf.contrib,因为它更加混乱(最终将被排除在 Tensorflow 2.0 版本之外)并坚持keras如果可能的话(因为它将是即将到来的主要前端张量流2.0) or tf.nn,因为它将成为tf.EstimatorAPI(尽管在我看来它的可读性要差得多)。

...或者考虑使用 PyTorch 来省去麻烦,至少在文档中提供了输入形状(及其含义)。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

与 Tensorflow 中的常规 LSTMCell 相比,使用 CudnnLSTM 训练时的结果不同 的相关文章

随机推荐

  • 如何使 JXTreeTable 对其顶部元素进行排序

    我知道 我已经查看了来源 JXTreeTable 上的排序已被禁用 但是 我希望允许仅根据根节点的直接子节点的值对所有列进行排序 假设我有这样的结构 Name Date File UID Root Mr X 1996 10 22 AE123
  • rspec 和 Shoulda - 互补还是替代?

    我已经使用shoulda有一段时间了 并且我已经阅读并使用了rspec 我没有做过深入的比较和对比 但在我看来 两者之间有一些重叠 但它们并不是一对一的替代 我正在考虑使用 rspec 在我的 Rails 系统中编写一些单元测试 而不替换用
  • Django ORM:覆盖子类中字段的 related_name

    我得到这个异常 django core exceptions FieldError 类 SpecialPlugin 中的本地字段 ticket 与基类 BasePlugin 中名称相似的字段发生冲突 这是我的模型 class BasePlu
  • 点击屏幕顶部状态栏时 UITableView 滚动到顶部

    我插入了一个UITableView在另一个里面UIViewController的观点 但是当我点击屏幕顶部的状态栏时 表视图不会滚动到顶部 这是 iOS 应用程序中的预期行为 我试过 self tableView setScrollsToT
  • 更改实例变量

    我有这个代码 class Yes def init self self a 1 def yes self if self a 1 print Yes else print No but yes class No Yes def no sel
  • 控制android状态栏图标

    我正在尝试对状态栏中图标的状态进行一些控制 我希望能够执行以下操作 保留图标 在状态栏中可见 只要 当应用程序运行时 即使用户选择清除状态栏 清除状态栏中的图标 如果应用程序退出 即使 特别是 它被杀死 我意识到当应用程序显式退出时我可以将
  • 将 Relay 与 React-Native 结合使用时的条件片段或嵌入式根容器

    我有relay与 一起工作react native 但我对如何最好地利用中继路由和根容器感到困惑 特别是在使用Navigator呈现多条路线 参加以下课程 var Nav React createClass renderScene rout
  • 测试用例和断言语句

    代码在这个问题让我思考 assert value gt 0 Precondition if value gt 0 Doit 我从不写 if 语句 断言就足够了 你全部can做 早早崩溃 经常崩溃 代码完成 states 断言语句使应用程序正
  • 以下位操作的优化机会?

    您认为 haswon 函数还有优化的空间吗 见下文 我认识到将参数类型从 int64 to unsigned int64使该功能比我想象的更快 也许还有优化的机会 更详细地说 我正在写一个连接四个游戏 最近我使用了Profiler很困并认识
  • 如何在 Visual Studio 2008 中自定义复制/粘贴行为?

    如何在 Visual Studio 2008 中自定义复制 粘贴行为 例如我创建一个新的 div div 然后将其复制并粘贴到同一个文件中 VisualStudio 粘贴 div div 而不是我复制的原文 更令人沮丧的是 当我尝试复制一组
  • 通过 Javascript 访问 Google Apps 公共电子表格

    花了很多时间看这个 似乎有关访问 Google apps 电子表格的少量信息维护得不是很好 今年的 Google IO 上宣布了增强的 Google apps 脚本 包括 UI 元素 这让我想到创建一个基于 Google 电子表格中的数据的
  • 在 MVC 操作中将 SSRS 报告导出为 PDF

    是的 我想将 SSRS 报告导出为 PDF 并从我的操作中返回它 我没有任何报告查看器 请建议我如何实现这一目标 到目前为止我已经做到了 public void SqlServerReport NetworkCredential nwc n
  • 指针和数组混淆的 K&R Qsort 示例

    我发现很难理解下面的代码片段 我理解所显示的指向函数风格的指针 但我发现混乱之处在于指示的行中 void qsort void v int left int right int comp void void int i last void
  • 带有数字填充的 CSS 计数器 [重复]

    这个问题在这里已经有答案了 可以垫吗counter数字取决于其价值 div counter reset ruler div gt span display block line height 1rem div gt span before
  • 在VS2022中的wsl2中调试控制台时读取输入

    我在 Visual Studio 2022 中创建了一个控制台应用程序 只有两行 WriteLine 和 ReadLine 在 Windows 上调试它时 会打开一个控制台 显示输出并等待输入 但是 如果我将其切换到 WSL 调试 我会在
  • Java:ArrayList如何管理内存

    在我的数据结构课程中 我们研究了 Java ArrayList 类 以及当用户添加更多元素时它如何增长底层数组 这是可以理解的 但是 我无法弄清楚当从列表中删除大量元素时 此类到底如何释放内存 查看源码 删除元素的方法有3种 public
  • 如何指定退出或中止的方法

    我有一个从 CLI 触发的方法 该方法具有一些显式退出或中止的逻辑路径 我发现 在为此方法编写规范时 RSpec 将其标记为失败 因为退出是异常 这是一个简单的例子 def cli method if condition puts Ever
  • 如何使用 sox 合并多个音频文件

    我使用以下命令通过 sox 将两个音频文件合并为一个 sox end mp3 p pad 6 0 sox m start mp3 output mp3 我想知道如何仅使用一个命令来合并 3 或 4 个音频文件 而不是使用 output mp
  • #java.lang.NoClassDefFoundError: org/apache/commons/digester/Digester

    我正进入 状态java lang NoClassDefFoundError org apache commons digester Digester错误 我被这个错误困扰了一个多月 我已经尝试了所有可用的 Digester 版本 并且还检查
  • 与 Tensorflow 中的常规 LSTMCell 相比,使用 CudnnLSTM 训练时的结果不同

    我正在 Python 中使用 Tensorflow 训练 LSTM 网络 并希望切换到 tf contrib cudnn rnn CudnnLSTM 以加快训练速度 我所做的被替换 cells tf nn rnn cell LSTMCell