在 while_loop 的上下文中使用 TensorArray 来累加值

2024-01-02

下面是 Tensorflow RNN Cell 的实现,旨在模拟本文中 Alex Graves 的算法 ACT:http://arxiv.org/abs/1603.08983 http://arxiv.org/abs/1603.08983.

在通过 rnn.rnn 调用的序列中的单个时间步(使用静态 sequence_length 参数,因此 rnn 是动态展开的 - 我使用固定批量大小 20),我们递归调用 ACTStep,生成 size(1,200) 的输出,其中RNN 单元的隐藏维度为 200,批量大小为 1。

使用 Tensorflow 中的 while 循环,我们进行迭代,直到累积的停止概率足够高。所有这些工作都相当顺利,但我在 while 循环内累积状态、概率和输出时遇到问题,我们需要这样做才能创建这些的加权组合作为最终的单元输出/状态。

我尝试使用一个简单的列表,如下所示,但是当编译图时,由于输出不在同一帧中,因此会失败(是否可以使用 control_flow_ops 中的“switch”函数将张量转发到它们是必需的,即在我们返回值之前的 add_n 函数?)。我也尝试过使用 TensorArray 结构,但我发现这很难使用,因为它似乎破坏了形状信息,并且手动替换它不起作用。我还没有找到太多关于 TensorArrays 的文档,我想它们可能主要供内部 TF 使用。

任何有关如何正确累积 ACTStep 生成的变量的建议将不胜感激。

class ACTCell(RNNCell):
"""An RNN cell implementing Graves' Adaptive Computation time algorithm"""
def __init__(self, num_units, cell, epsilon, max_computation):

    self.one_minus_eps = tf.constant(1.0 - epsilon)
    self._num_units = num_units
    self.cell = cell
    self.N = tf.constant(max_computation)
@property
def input_size(self):
    return self._num_units
@property
def output_size(self):
    return self._num_units
@property
def state_size(self):
    return self._num_units

def __call__(self, inputs, state, scope=None):

    with vs.variable_scope(scope or type(self).__name__):

        # define within cell constants/ counters used to control while loop
        prob = tf.get_variable("prob", [], tf.float32,tf.constant_initializer(0.0))
        counter = tf.get_variable("counter", [],tf.float32,tf.constant_initializer(0.0))
        tf.assign(prob,0.0)
        tf.assign(counter, 0.0)

        # the predicate for stopping the while loop. Tensorflow demands that we have
        # all of the variables used in the while loop in the predicate.
        pred = lambda prob,counter,state,input,\
                      acc_state,acc_output,acc_probs:\
            tf.logical_and(tf.less(prob,self.one_minus_eps), tf.less(counter,self.N))

        acc_probs = []
        acc_outputs = []
        acc_states = []


        _,iterations,_,_,acc_states,acc_output,acc_probs = \
        control_flow_ops.while_loop(pred,
        self.ACTStep,
        [prob,counter,state,input,acc_states,acc_outputs,acc_probs])

    # TODO:fix last part of this, need to use the remainder.
    # TODO: find a way to accumulate the regulariser

    # here we take a weighted combination of the states and outputs 
    # to use as the actual output and state which is passed to the next timestep.

    next_state = tf.add_n([tf.mul(x,y) for x,y in zip(acc_probs,acc_states)])
    output = tf.add_n([tf.mul(x,y) for x,y in zip(acc_probs,acc_outputs)])


    return output, next_state

def ACTStep(self,prob,counter,state,input, acc_states,acc_outputs,acc_probs):

    output, new_state = rnn.rnn(self.cell, [input], state, scope=type(self.cell).__name__)

    prob_w = tf.get_variable("prob_w", [self.cell.input_size,1])
    prob_b = tf.get_variable("prob_b", [1])
    p = tf.nn.sigmoid(tf.matmul(prob_w,new_state) + prob_b)

    acc_states.append(new_state)
    acc_outputs.append(output)
    acc_probs.append(p)

    return [tf.add(prob,p),tf.add(counter,1.0),new_state, input,acc_states,acc_outputs,acc_probs]

我将在这个回复前言,这不是一个完整的解决方案,而是一些关于如何改进你的细胞的评论。

首先,在 ACTStep 函数中,您调用rnn.rnn对于一个时间步(定义为[input]。如果您正在执行单个时间步,那么简单地使用实际的时间步可能会更有效self.cell通话功能。您将看到张量流中使用相同的机制细胞包装器 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell.py#L708

你提到你已经尝试过使用TensorArrays。您是否正确打包和解包张量数组?这里有一个repo https://github.com/ofirnachum/sequence_gan/blob/master/model.py你会在哪里找到model.py张量数组已正确打包和解包。

您还问是否有一个功能control_flow_ops这将需要累积所有张量。我想你正在寻找tf.control_dependencies https://www.tensorflow.org/versions/r0.9/api_docs/python/framework.html#control_dependencies

您可以在 control_dependicies 中列出所有输出张量操作,这将需要张量流来计算该点之前的所有张量。

另外,它看起来像你的counter变量是可训练的。您确定要这样吗?如果您将计数器加一,则可能不会产生正确的结果。另一方面,您可以故意使其保持可训练性,以便在思考成本函数的最后对其进行区分。

另外我相信 Remainder 函数应该在您的脚本中:

remainder = 1.0 - tf.add_n(acc_probs[:-1])
#note that there is a -1 in the list as you do not want to grab the last probability

这是我编辑的代码版本:

class ACTCell(RNNCell):
    """An RNN cell implementing Graves' Adaptive Computation time algorithm
    Notes: https://www.evernote.com/shard/s189/sh/fd165646-b630-48b7-844c-86ad2f07fcda/c9ab960af967ef847097f21d94b0bff7

    """
    def __init__(self, num_units, cell, max_computation = 5.0, epsilon = 0.01):

        self.one_minus_eps = tf.constant(1.0 - epsilon) #episolon is 0.01 as found in the paper
        self._num_units = num_units
        self.cell = cell
        self.N = tf.constant(max_computation)

    @property
    def input_size(self):
        return self._num_units
    @property
    def output_size(self):
        return self._num_units
    @property
    def state_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None):

        with vs.variable_scope(scope or type(self).__name__):

            # define within cell constants/ counters used to control while loop
            prob = tf.constant(0.0, shape = [batch_size]) 
            counter = tf.constant(0.0, shape = [batch_size])

            # the predicate for stopping the while loop. Tensorflow demands that we have
            # all of the variables used in the while loop in the predicate.
            pred = lambda prob,counter,state,input,acc_states,acc_output,acc_probs:\
                tf.logical_and(tf.less(prob,self.one_minus_eps), tf.less(counter,self.N))

            acc_probs, acc_outputs, acc_states  = [], [], []

            _,iterations,_,_,acc_states,acc_output,acc_probs = \
            control_flow_ops.while_loop(
            pred,
            self.ACTStep, #looks like he purposely makes the while loop here
            [prob, counter, state, input, acc_states, acc_outputs, acc_probs])

        '''mean-field updates for states and outputs'''
        next_state = tf.add_n([x*y for x,y in zip(acc_probs,acc_states)])
        output = tf.add_n([x*y for x,y in zip(acc_probs,acc_outputs)])

        remainder = 1.0 - tf.add_n(acc_probs[:-1]) #you take the last off to avoid a negative ponder cost #the problem here is we need to take the sum of all the remainders
        tf.add_to_collection("ACT_remainder", remainder) #if this doesnt work then you can do self.list based upon timesteps
        tf.add_to_collection("ACT_iterations", iterations)
        return output, next_state 

    def ACTStep(self,prob, counter, state, input, acc_states, acc_outputs, acc_probs):

        '''run rnn once'''
        output, new_state = rnn.rnn(self.cell, [input], state, scope=type(self.cell).__name__)

        prob_w = tf.get_variable("prob_w", [self.cell.input_size,1]) 
        prob_b = tf.get_variable("prob_b", [1])
        halting_probability = tf.nn.sigmoid(tf.matmul(prob_w,new_state) + prob_b) 


        acc_states.append(new_state)
        acc_outputs.append(output)
        acc_probs.append(halting_probability) 

        return [p + prob, counter + 1.0, new_state, input,acc_states,acc_outputs,acc_probs]


    def PonderCostFunction(self, time_penalty = 0.01):
        '''
        note: ponder is completely different than probability and ponder = roe

        the ponder cost function prohibits the rnn from cycling endlessly on each timestep when not much is needed
        '''
        n_iterations = tf.get_collection_ref("ACT_iterations")
        remainder = tf.get_collection_ref("ACT_remainder")
        return tf.reduce_sum(n_iterations + remainder) #completely different from probability

这是一篇我自己一直在努力实现的复杂论文。我不介意与您合作在 Tensorflow 中完成它。如果您有兴趣,请在 Skype 上添加我的 LeavesBreathe,我们可以从那里开始。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 while_loop 的上下文中使用 TensorArray 来累加值 的相关文章

随机推荐

  • 如何以 OOP 风格使用 TensorFlow?

    具体来说 当使用 TensorFlow 以 OOP 风格构建模型时 我应该在哪里构建图 我应该在哪里启动会话来运行图表 此案例的最佳实践是什么 In TensorFlow 力学 101 https www tensorflow org tu
  • ES6 fetch 函数返回未定义[重复]

    这个问题在这里已经有答案了 我有以下代码 function fetchDemo var result fetch countriesUrl then function response return response json then f
  • 画布未在reactjs中渲染

    我想在我正在开发的网站上添加画布 但我似乎可以理解为什么画布没有显示 可能是什么问题 以下是我尝试过的 当我将鼠标悬停在标题上时 它显示画布正在更新 但屏幕上没有显示任何内容 画布 jsx export class Canvas exten
  • 在 R 中按模式重命名列

    我想按特定模式重命名数据框中的所有列 我的输入 Log NE122 Log NE244 Log NE144 0 33 0 98 1 0 我的预期输出 NE122 NE244 NE144 0 33 0 98 1 0 Cheers 您可以使用正
  • 在 Visual Studio 中开发 Azure Function 时存储帐户无效

    我正在使用 C 在 Visual Studio 中开发 Azure Function 我在位于代理后面的开发机器上本地运行它 但是不断收到此错误 Exception binding parameter Invalid storage acc
  • 打字稿路径无法解析

    Here https github com oleersoy typescript pathsGithub MCVE 显示了一个问题 npm run compile显示错误 我正在尝试这样做 import Todo from test 但这
  • 检测用户是否在颤动上按下 home / tab 的代码?

    是否有任何代码可以检测用户是否按下了 home tab 我想让我的音乐在按下时暂停 通过添加观察者来跟踪生命周期事件WidgetsBinding然后在应用程序暂停时暂停音乐 你可以看看this https github com flutte
  • 核心数据executeFetchRequest抛出NSGenericException(枚举时集合发生了变化)

    我正在使用 Core Data 开发 iPhone 应用程序 所有用户数据应与我们的服务器同步 为此 我创建了 NSOperation 的子类 它从我们的 Web 服务加载新数据并创建相应的托管对象 为了维护它们之间的关系 每个对象都使用远
  • 哪个是最好的 git 托管软件? - Gitolite vs. Gitlab vs. Gitorius [关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 我正在寻找适合多个用户的 git 托管环境 因此我搜索了之间的比较Gitolite Gitlab and Gitorius 但我没有得到任何有用
  • YAML:YAML 中的字符串需要引号吗?

    我正在尝试编写一个用于 Rails 项目国际化的 YAML 字典 不过我有点困惑 因为在某些文件中我看到字符串用双引号引起来 而在某些文件中则没有 需要考虑的几点 示例1 https github com plataformatec dev
  • Powershell:使用字符串匹配条件将单个文件拆分为多个文件

    我有一个包含 1GB 数据的文件 该数据实际上是数十个或数千个单独的迷你文件 我需要提取每个单独的文件并将它们放入自己单独的不同文件中 所以本质上 我需要从单个文件变成 30K 单独的文件 这是 我的文件 的示例 文件名 1 版本 1 32
  • CRUDRespository 中的更新或 SaveorUpdate,是否有任何可用选项

    我正在尝试使用 My Entity bean 执行 CRUD 操作 CRUDRepository提供标准方法find delete and save但没有可用的通用方法 例如saveOrUpdate Entity entity 进而调用Hi
  • 如何将json对象显示为html?

    我的 Json 对象是这样的 attributes Code SGL Total 19421340 27 DayPrice Date 2016 07 22 Rate 4900439 85 Date 2016 07 23 Rate 48451
  • 绕过 Google 电子表格中的循环引用

    我有一个谷歌文档电子表格 有两列 A 和 B B 的值只是 A 中不同格式的值 并且我在 B 列中有一个公式可以进行转换 有时我没有 A 格式的值 但有 B 格式的值 我想通过在 A 列中添加进行反向转换的公式来自动获取 A 列中 A 格式
  • 如何在 vue.js 构建上重命名 index.html?

    我想重命名index html产生于npm run build 我在 webpack 配置中找不到任何内容 我还创建了一个vue config js此处描述 https github com vuejs vue cli tree dev d
  • React Redux 工具包:类型错误:无法读取未定义的属性“值”

    在我的项目中 我为 2 个不同的状态场景实现了 React Redux 工具包 并且它们工作得很好 现在我需要为 Redux 实现第三个状态场景 因此我遵循与前 2 个状态场景相同的模式 灵感来自 https react redux js
  • 为什么我的 Django 表单没有引发任何错误?

    我有一个简单的表单 每当用户在表单上做错事时 我想在 Django 上引发验证错误 问题是我设置了表单验证 但是当提交表单时使用错误的值时 它会通过 我想知道为什么会发生这种情况以及如何避免这种情况 这是 html 形式
  • 如何检查浏览器是否支持flash?

    我有一个 Flash 横幅 如果客户端浏览器没有启用 Flash 我想用静态图像替换它 我想知道我是否可以用 php 做到这一点 或者是否有人知道一个好方法 Thanks 允许 您的 Flash 影片 降级
  • 使用 Flask-limiter 限制端点速率

    我知道并且爱flask limiter来自较旧的项目 现在我想用它在我的flask restplus为基础的项目 我的最终解决方案将使我能够在每个方法级别上进行速率限制 因此 post 方法的费率与 get 方法的费率不同 但如果我可以定义
  • 在 while_loop 的上下文中使用 TensorArray 来累加值

    下面是 Tensorflow RNN Cell 的实现 旨在模拟本文中 Alex Graves 的算法 ACT http arxiv org abs 1603 08983 http arxiv org abs 1603 08983 在通过