Tensorflow中GRU单元的解释?

2024-04-08

以下是 Tensorflow 的代码GRUCell单元显示了当先前的隐藏状态与序列中的当前输入一起提供时获得更新的隐藏状态的典型操作。

  def __call__(self, inputs, state, scope=None):
    """Gated recurrent unit (GRU) with nunits cells."""
    with vs.variable_scope(scope or type(self).__name__):  # "GRUCell"
      with vs.variable_scope("Gates"):  # Reset gate and update gate.
        # We start with bias of 1.0 to not reset and not update.
        r, u = array_ops.split(1, 2, _linear([inputs, state],
                                             2 * self._num_units, True, 1.0))
        r, u = sigmoid(r), sigmoid(u)
      with vs.variable_scope("Candidate"):
        c = self._activation(_linear([inputs, r * state],
                                     self._num_units, True))
      new_h = u * state + (1 - u) * c
return new_h, new_h

但我没有看到任何weights and biases这里。 例如我的理解是r and u需要将权重和偏差与当前输入和/或隐藏状态相乘以获​​得更新的隐藏状态。

我写了一个gru单元如下:

def gru_unit(previous_hidden_state, x):
    r  = tf.sigmoid(tf.matmul(x, Wr) + br)
    z  = tf.sigmoid(tf.matmul(x, Wz) + bz)
    h_ = tf.tanh(tf.matmul(x, Wx) + tf.matmul(previous_hidden_state, Wh) * r)
    current_hidden_state = tf.mul((1 - z), h_) + tf.mul(previous_hidden_state, z)
    return current_hidden_state

这里我明确地使用了权重Wx, Wr, Wz, Wh和偏见br, bh, bz等来获取更新的隐藏状态。这些权重和偏差是训练后学习/调整的。

如何利用 Tensorflow 的内置功能GRUCell达到与上面相同的结果?


它们在那里,您只是在代码中看不到它们,因为 _线性 函数添加了权重和偏差。

r, u = array_ops.split(1, 2, _linear([inputs, state],
                                             2 * self._num_units, True, 1.0))

...

def _linear(args, output_size, bias, bias_start=0.0, scope=None):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  Args:
    args: a 2D Tensor or a list of 2D, batch x n, Tensors.
    output_size: int, second dimension of W[i].
    bias: boolean, whether to add a bias term or not.
    bias_start: starting value to initialize the bias; 0 by default.
    scope: VariableScope for the created subgraph; defaults to "Linear".

  Returns:
    A 2D Tensor with shape [batch x output_size] equal to
    sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.
  """
  if args is None or (nest.is_sequence(args) and not args):
    raise ValueError("`args` must be specified")
  if not nest.is_sequence(args):
    args = [args]

  # Calculate the total size of arguments on dimension 1.
  total_arg_size = 0
  shapes = [a.get_shape().as_list() for a in args]
  for shape in shapes:
    if len(shape) != 2:
      raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
    if not shape[1]:
      raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
    else:
      total_arg_size += shape[1]

  # Now the computation.
  with vs.variable_scope(scope or "Linear"):
    matrix = vs.get_variable("Matrix", [total_arg_size, output_size])
    if len(args) == 1:
      res = math_ops.matmul(args[0], matrix)
    else:
      res = math_ops.matmul(array_ops.concat(1, args), matrix)
    if not bias:
      return res
    bias_term = vs.get_variable(
        "Bias", [output_size],
        initializer=init_ops.constant_initializer(bias_start))
  return res + bias_term
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow中GRU单元的解释? 的相关文章

  • 从图中删除节点或重置整个默认图

    使用默认全局图时 是否可以在添加节点后将其删除 或者将默认图重置为空 当我在 IPython 中交互地使用 TF 时 我发现自己必须反复重新启动内核 如果可能的话 我希望能够更轻松地尝试图表 更新 11 2 2016 tf reset de
  • 跨多个 GPU/机器的 TF-Slim 的配置/标志

    我很好奇是否有关于如何使用部署 model deploy py 在多台机器上的多个 GPU 上运行 TF Slim models slim 的示例 该文档非常好 但我缺少一些内容 具体来说 需要为worker device和ps devic
  • Keras:加载多个模型并在不同线程中进行预测

    我正在使用带有张量流核心的 Keras 我想在构造函数中加载 2 个不同的模型 然后在不同的线程中进行预测 根据请求 我尝试在张量流图上下文中加载这些模型 但它不起作用 我的代码 from keras models import load
  • 什么是 ANN 中的纪元以及它如何转换为 MATLAB 中的代码?

    我试图理解 并可视化 训练人工神经网络的时代到底是什么 我们有一个包含约 7000 个产品的训练集 其中有 10 个特征 输入 这些产品必须根据这 10 个输入分为 7 个类别 我们的 ANN 有 10 个输入 这些输入进入由 10 个神经
  • 从 [tensorflow 1.00] 中的 softmax 层提取概率

    使用张量流 我有一个 LSTM 分类模型 以 softmax 作为最终节点 这是我的 softmax 层 with tf name scope Softmax as scope with tf variable scope Softmax
  • 具有高级计算功能的 Keras 自定义层

    我想写一些自定义的Keras分层并在层中进行一些高级计算 例如使用 Numpy Scikit OpenCV 我知道有一些数学函数keras backend可以对张量进行操作 但我需要一些更高级的功能 但是 我不知道如何正确实现这一点 我收到
  • Tensorflow推荐的系统规格?

    我开始在我的 RHEL 6 5 机器上安装 Tensorflow 但事实证明 Tensorflow 需要 glibc gt 2 17 而 rhel 6 5 上默认的 glibc 是 2 12 我想知道是否有人可以帮助我了解张量流的最低 推荐
  • 您必须使用 dtype float(Tensorflow) 为占位符张量“Placeholder”提供值

    import tensorflow as tf import os import sklearn preprocessing import pandas as pd import numpy as np print os getcwd os
  • 导入tensorflow时,出现以下错误:没有名为“numpy.core._multiarray_umath”的模块

    我已经安装了 Ancaconda3 和 Tensorflow 当我尝试在 python shell 中导入 Tensorflow 时 收到以下错误 ModuleNotFoundError 没有名为 numpy core multiarray
  • 用于测试张量流安装的速度基准

    我怀疑我的 GPU 机器上是否正确配置了张量流 因为在我精美的 GPU 机器上训练一个简单的线性回归模型 批量大小 32 1500 个输入特征 150 个输出变量 的每次迭代速度比在笔记本电脑上慢 100 倍 我使用的是 Titan X 配
  • 使用 Keras 时,验证集中未见的类别会出现错误

    我有由数值变量和分类变量组成的数据 分类变量有很多类别 因此我使用嵌入来表示这些类别 我的模型是一个简单的神经网络 我知道当你定义嵌入层时你需要通过input dim number of categories 1为了解释训练中看不见的类别
  • 神经网络的层和神经元

    我想更多地了解神经网络 我正在开发一个 C 程序来制作神经网络 但我坚持使用反向传播算法 很抱歉没有提供一些工作代码 我知道有很多库可以用多种语言创建神经网络 但我更喜欢自己制作一个 关键是我不知道要实现特定目标 例如模式识别或函数近似或其
  • 如何在 Tensorflow 中使用预训练的 Word2Vec 模型

    我有一个Word2Vec训练过的模型Gensim 我如何使用它Tensorflow for Word Embeddings 我不想在 Tensorflow 中从头开始训练嵌入 有人可以告诉我如何用一些示例代码来做到这一点吗 假设您有一个字典
  • Tensorflow seq2seq 获取序列隐藏状态

    我不久前才开始研究tensorflow 我正在研究 seq2seq 模型 并以某种方式让教程起作用 但我一直坚持获取每个句子的状态 据我了解 seq2seq 模型采用输入序列并通过 RNN 为序列生成隐藏状态 随后 模型使用序列的隐藏状态来
  • 交换keras中的张量轴

    我想将图像批次的张量轴从 batch size row col ch 交换为 批次大小 通道 行 列 在 numpy 中 这可以通过以下方式完成 X batch np moveaxis X batch 3 1 我该如何在 Keras 中做到
  • 在tensorflow.js中对张量进行分区、屏蔽或过滤

    我有 2 个相同长度的张量 data and groupIds 我想分开data通过相应的值分成几组groupId 例如 const data tf tensor 1 2 3 4 5 const groupIds tf tensor 0 1
  • 使用 Tkinter 显示 numpy 数组中的图像

    我对 Python 缺乏经验 第一次使用 Tkinter 制作一个 UI 显示我的数字分类程序与 mnist 数据集的结果 当图像来自 numpy 数组而不是我的 PC 上的文件路径时 我有一个关于在 Tkinter 中显示图像的问题 我为
  • 图书馆神经实验室培训纽夫

    我对 python 和 Neurolab 的使用还很陌生 我在前馈神经网络的训练方面遇到了问题 我已经构建了如下网络 net nl net newff 1 1 64 60 1 net init testerr net train Input
  • 如何从张量流数据集迭代器返回同一批次两次?

    我正在转换一些旧代码以使用数据集 API 此代码使用feed dict将一批数据送入列车运行 实际上是三次 然后重新计算损失以供显示使用同一批 所以我需要一个迭代器来返回完全相同的批次两次 或多次 不幸的是 我似乎找不到一种使用张量流数据集
  • 具有动态 num_partitions 的动态分区

    变量num partitions在方法中tf dynamic partition不是一个Tensor 但是一个int 因此 如果事先不知道分区的数量 则无法通过计算唯一值的数量等方式从数据中推断出分区的数量 也无法通过tf placehol

随机推荐

  • 如何在“border-*”属性中使用百分比?

    我有使用 Twitter Bootstrap 3 的代码 nav with right arrow 我使用它创建的border 特性 但是如果我在中使用很长的文本right arrow 它不会扩展 如果我使用百分比 代码将无法工作 示例Js
  • 复制virtualenv文件夹后如何在Cygwin中激活virtualenv

    完整的初学者在这里 尝试构建一个 Flask Web 应用程序 使用 Windows 8 在 Cygwin 中激活我的 python virtualenv 时遇到一些问题 到目前为止我一直在使用 git shell 没有任何问题 我将文件夹
  • React.js:将默认值设置为 prop

    我制作了这个组件来创建一个简单的按钮 class AppButton extends Component setOnClick if this props onClick typeof this props onClick function
  • 在 ASP.NET MVC 3 应用程序中扩展 Windows 身份验证

    经过大量谷歌搜索并阅读了有关如何在 ASP NET 应用程序中管理混合模式身份验证的几种解决方案后 我仍然没有适合我的问题的解决方案 我必须为一堆不同的用户组实现一个 Intranet 应用程序 到目前为止 我一直使用 Windows 身份
  • 无法在 ubuntu 19.04 上安装 libzmq3-dev

    我正在尝试安装libzmq3 dev on 乌班图19 04 使用命令 sudo apt install build essential libsocketcan dev libzmq3 dev 我收到消息 gt Some packages
  • Pentaho Spoon 工具转换顺序

    我正在尝试设计一个 ETL 结构 但我陷入了以下步骤 正如你所看到的 我有 3 个步骤 每个步骤都有一个FK上一步的值 例如TABLE3有一个列外键约束这表明PK值在TABLE2 and TABLE2与 具有相同的关系TABLE1 问题是
  • 如何在我的 Maven 项目中正确包含“org.apache.catalina.filters.SetCharacterEncodingFilter”过滤器?

    我使用 Maven 3 3 和 JBoss 7 1 3 Final Java 6 我想在我的 Web 应用程序中包含一个过滤器 以便所有传入请求数据都将编码为 UTF 8 所以我将其添加到我的 web xml 文件中
  • Powershell CheckedListBox 检查是否在字符串/数组中

    我已经开始学习 Powershell 但在花了几个小时解决一个问题后陷入困境 我可以找到除 Powershell 之外的多种语言的解决方案 我需要对 CheckedListBox 中的每个项目进行检查 该项目与名为的分号分隔字符串中的任何值
  • WPF 中 WinForms TextBox.Validating 事件的等效项

    在 WinForms 中 我可以处理 Validated 事件 以便在用户更改 TextBox 中的文本后执行某些操作 与 TextChanged 不同 Validated 不会在每次字符更改时触发 它仅在用户完成后触发 WPF 中是否有任
  • 我到底必须在 viewDidUnload 中做什么?

    我倾向于在 dealloc 中释放我的东西 现在 iPhone OS 3 0 引入了这个有趣的 viewDidUnload 方法 他们说 释放所有保留的子视图 主要视图 例如自我我的出口 零 因此 当视图控制器的视图从内存中启动时 view
  • Pandas - 按一列分组,按另一列排序,从第三列获取值

    我想采用 pandas 数据框 按一列对其进行分组 按另一列对其进行排序 并从第三列中获取第一个元素并填充原始数据框 这是我原来的 df 我将按 col 1 分组 按 col 2 升序 排序 并从 col 3 中取出第一个元素并用结果填充
  • 对角线穿过视图

    根据某些条件 我必须对角剪切列表单元格 为此 我使用以下代码制作了对角线可绘制图像 对角线 xml
  • 沿多边形边界随机采样点

    I am trying to randomly sample points on a polygon boundary made of arbitrary number of points The polygon consist of a
  • C++中的默认参数

    考虑以下 int foo int x int z 0 int foo int x int y int z 0 如果我像这样调用这个函数 foo 1 2 编译器如何知道使用哪一个 它不会 因此这个例子不会编译干净 它会给你一个编译错误 它会给
  • Cardview 涟漪效应不起作用

    最小 SDK 为 21 当我单击回收器适配器中的卡片视图时 不会发生连锁反应 只会转到下一个屏幕 recyclerview 位于片段内
  • JDBC 无法加载数据源的工厂类

    我已经遇到这个问题好几天了 但没有设法解决它 我使用的是 tomcat 7 0 我完全无法连接 mysql 数据库 我正在编写的应用程序是一个使用eclipse IDE的jsp动态网站 TomCat 7 启动时出现此错误 WARNING F
  • 为什么 Z3 在这个简单的输入上返回“未知”?

    这是输入 set option auto config false set option mbqi false declare sort T6 declare sort T7 declare fun set23 T7 T7 Bool ass
  • 在 Aptana Studio 3 中禁用 CSS 验证

    有人知道如何使用 Aptana Studio 3 禁用 CSS 验证吗 在版本 3 0 4 中 即使完全完成后 警告仍然存在禁用 W3C CSS 验证器 https stackoverflow com questions 6652793 h
  • 在最近的 JVM 中,不可见引用仍然是一个问题吗?

    我正在读书Java 平台性能 http java sun com docs books performance 1st edition html JPAppGC fm html 遗憾的是 自从我最初提出这个问题以来 该链接似乎已经从互联网上
  • Tensorflow中GRU单元的解释?

    以下是 Tensorflow 的代码GRUCell单元显示了当先前的隐藏状态与序列中的当前输入一起提供时获得更新的隐藏状态的典型操作 def call self inputs state scope None Gated recurrent