Tensorflowdynamic_rnn参数含义

2024-01-23

我正在努力理解神秘的 RNN 文档。任何有关以下内容的帮助将不胜感激。

tf.nn.dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None)

我正在努力理解这些参数与数学 LSTM 方程和 RNN 定义的关系。单元格展开尺寸在哪里?它是由输入的“max_time”维度定义的吗? batch_size只是为了方便分割长数据还是与小批量SGD有关?输出状态是否跨批次传递?


tf.nn.dynamic_rnn接受一批(具有小批量的含义)不相关的序列。

  • cell是您要使用的实际单元(LSTM、GRU,...)
  • inputs形状为batch_size x max_time x input_size其中 max_time 是最长序列中的步数(但所有序列可以具有相同的长度)
  • sequence_length是一个大小向量batch_size其中每个元素给出批次中每个序列的长度(如果所有序列的大小相同,则将其保留为默认值。该参数定义单元展开尺寸。

隐藏状态处理

处理隐藏状态的通常方法是在隐藏状态之前定义一个初始状态张量dynamic_rnn,例如这样:

hidden_state_in = cell.zero_state(batch_size, tf.float32) 
output, hidden_state_out = tf.nn.dynamic_rnn(cell, 
                                             inputs,
                                             initial_state=hidden_state_in,
                                             ...)

在上面的代码片段中,两个hidden_state_in and hidden_state_out具有相同的形状[batch_size, ...] (实际形状取决于您使用的单元格类型,但重要的是第一个维度是批量大小).

这边走,dynamic_rnn每个序列都有一个初始隐藏状态。它将在每个序列的时间步长之间传递隐藏状态inputs参数本身, and hidden_state_out将包含批次中每个序列的最终输出状态。同一批次的序列之间不会传递任何隐藏状态,而只会在同一序列的时间步之间传递。

什么时候需要手动反馈隐藏状态?

通常,当您进行训练时,每个批次都是无关的,因此您不必在执行训练时反馈隐藏状态session.run(output).

但是,如果您正在测试,并且需要每个时间步骤的输出(即您必须执行session.run()在每个时间步)您将需要使用如下所示的方法来评估并反馈输出隐藏状态:

output, hidden_state = sess.run([output, hidden_state_out],
                                feed_dict={hidden_state_in:hidden_state})

否则tensorflow将只使用默认值cell.zero_state(batch_size, tf.float32)在每个时间步,这相当于在每个时间步重新初始化隐藏状态。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflowdynamic_rnn参数含义 的相关文章

  • 打印出网络架构中每一层的形状

    在 Keras 中 我们可以如下定义网络 有什么办法可以输出每层之后的形状 例如 我想打印出以下形状inputs在定义行之后inputs 然后打印出形状conv1在定义行之后conv1 etc inputs Input 1 img rows
  • Tensorflow DecodeJPEG:预期图像(JPEG、PNG 或 GIF)以“\000\000\000\000\000\000\000\00”开头的格式未知

    我正在循环浏览图像文件夹 这种情况不断发生 tensorflow python framework errors impl InvalidArgumentError 预期的图像 JPEG PNG或GIF 以 000 000 000 000
  • 无法从 DenseVariational 获得合理的结果

    我正在尝试使用以下大小的数据集 正弦曲线 进行回归问题500 首先 我尝试使用 2 个密集层 每个层有 10 个单元 model tf keras Sequential tf keras layers Dense 10 activation
  • 如何保存 Tensorflow.js 模型?

    我想制作一个创建 保存和训练 tensorflow js 模型的用户界面 但我无法在创建模型后保存模型 我什至从tensorflow js文档复制了这段代码 但它不起作用 const model tf sequential layers t
  • 张量流中有哪些资产?

    我正在阅读有关保存和恢复模型的张量流教程 并遇到以下声明 If assets need to be saved and written or copied to disk they can be provided when the firs
  • 如何在对象检测 API Tensorflow 中仅检测人体

    我在用tensorflow对象检测 API 用于检测对象 它在我的 Windows 系统中运行良好 我如何对其进行更改以仅检测提到的对象 例如 我只想检测人类而不是所有对象 根据此中的第 1 条评论answer https stackove
  • 在 Android 上保持 TensorFlow 模型加密

    我搜索了解是否有一种技术可以在 Android 应用程序中保持经过训练的张量流模型 pb 文件 的安全 但没有找到任何有用的东西 我正在发布一个包含我在训练集上构建的张量流模型的应用程序 当我发布该应用程序时 任何人都可以访问该模型并将其用
  • 跨多个 GPU/机器的 TF-Slim 的配置/标志

    我很好奇是否有关于如何使用部署 model deploy py 在多台机器上的多个 GPU 上运行 TF Slim models slim 的示例 该文档非常好 但我缺少一些内容 具体来说 需要为worker device和ps devic
  • 从 [tensorflow 1.00] 中的 softmax 层提取概率

    使用张量流 我有一个 LSTM 分类模型 以 softmax 作为最终节点 这是我的 softmax 层 with tf name scope Softmax as scope with tf variable scope Softmax
  • 使用输入管道时如何替换 feed_dict?

    假设您有一个已与feed dict到目前为止将数据注入到图表中 每隔几个时期 我就会通过将任一数据集的一批数据输入到我的图表中来评估训练和测试损失 现在 出于性能原因 我决定使用输入管道 看看这个虚拟示例 import tensorflow
  • 您必须使用 dtype float(Tensorflow) 为占位符张量“Placeholder”提供值

    import tensorflow as tf import os import sklearn preprocessing import pandas as pd import numpy as np print os getcwd os
  • Tensorboard 和 Dropout 层

    我有一个非常基本的查询 我制作了 4 个几乎相同 差异在于输入形状 的 CNN 并在连接到全连接层的前馈网络时合并了它们 几乎相同的 CNN 的代码 model3 Sequential model3 add Convolution2D 32
  • tf.gfile 在 TensorFlow 中起什么作用?

    我见过人们使用以下几个函数tf gfile例如tf gfile GFile or tf gfile Exists 我有一个想法tf gfile处理文件 但是 我无法找到官方文档来了解它还提供了什么 如果你能帮我的话那就太好了 对于登陆这里的
  • AttributeError:模块“keras.engine”没有属性“Layer”

    当我试图运行时Parking Slot mask rcnn py文件我收到如下错误mrcnn model py文件我该如何解决 gt 2021 06 17 08 25 18 585897 W tensorflow stream execut
  • 使用 flow_from_dataframe y_col 的正确“值”是什么

    我正在用 pandas 读取 csv 文件 并给出存储在中的列名称colname colnames file label Read data from file data pd read csv Hand Annotations 2 csv
  • conv1D 中形状的尺寸

    我尝试过构建一个只有一层的 CNN 但遇到了一些问题 事实上 编译器告诉我 ValueError 检查模型输入时出错 预期的 conv1d 1 input 具有 3 个维度 但得到形状为 569 30 的数组 这是代码 import num
  • 如何在 Tensorflow 中使用预训练的 Word2Vec 模型

    我有一个Word2Vec训练过的模型Gensim 我如何使用它Tensorflow for Word Embeddings 我不想在 Tensorflow 中从头开始训练嵌入 有人可以告诉我如何用一些示例代码来做到这一点吗 假设您有一个字典
  • Tensorflow 到 ONNX 的转换

    我目前正在尝试转换我使用本教程创建的已保存 且正在工作 的 pb 文件 https github com thtrieu darkflow https github com thtrieu darkflow 到 onnx 文件中 我目前正在
  • ubuntu 20.04 上无法获取卷积算法错误~tensorflow-gpu

    我有一个 NVIDIA 2070 RTX GPU 我的操作系统是 Ubuntu20 04 我已经使用 conda 安装了tensorflow gpu 包 我有not安装了 CUDA toolkit 我相信它还会安装 CUDA toolkit
  • 使用 Tkinter 显示 numpy 数组中的图像

    我对 Python 缺乏经验 第一次使用 Tkinter 制作一个 UI 显示我的数字分类程序与 mnist 数据集的结果 当图像来自 numpy 数组而不是我的 PC 上的文件路径时 我有一个关于在 Tkinter 中显示图像的问题 我为

随机推荐