tf.assign 到变量切片在 tf.while_loop 中不起作用

2023-12-04

下面的代码有什么问题?这tf.assign当应用于 a 的切片时,op 工作得很好tf.Variable如果它发生在循环之外。但是,在这种情况下,它会给出以下错误。

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a = tf.Variable(v, name = 'a')

def cond(i, a):
    return i < n 

def body(i, a):
    tf.assign(a[i], a[i-1] + a[i-2])
    return i + 1, a

i, b = tf.while_loop(cond, body, [2, a]) 

结果是:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3210, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2942, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2879, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home/hrbigelow/ai/lb-wavenet/while_var_test.py", line 11, in body
    tf.assign(a[i], a[i-1] + a[i-2])
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py", line 220, in assign
    return ref.assign(value, name=name)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 697, in assign
    raise ValueError("Sliced assignment is only supported for variables")
ValueError: Sliced assignment is only supported for variables

您的变量不是循环内部运行的操作的输出,它是位于循环外部的外部实体。因此您不必将其作为参数提供。

此外,您需要强制进行更新,例如使用tf.control_dependencies in body.

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a = tf.Variable(v, name = 'a')

def cond(i):
    return i < n 

def body(i):
    op = tf.assign(a[i], a[i-1] + a[i-2])
    with tf.control_dependencies([op]):
      return i + 1

i = tf.while_loop(cond, body, [2])

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
i.eval()
print(a.eval())
# [ 1  1  2  3  5  8 13 21 34 55 89]

也许您可能需要谨慎并设置parallel_iterations=1强制循环按顺序运行。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tf.assign 到变量切片在 tf.while_loop 中不起作用 的相关文章

  • 将 saxon 与 python 结合使用

    我需要使用 python 处理 XSLT 目前我正在使用仅支持 XSLT 1 的 lxml 现在我需要处理 XSLT 2 有没有办法将 saxon XSLT 处理器与 python 一起使用 有两种可能的方法 设置一个 HTTP 服务 接受
  • PyUSB 1.0:NotImplementedError:此平台不支持或未实现操作

    我刚刚开始使用 pyusb 基本上我正在玩示例代码here https github com walac pyusb blob master docs tutorial rst 我使用的是 Windows 7 64 位 并从以下地址下载 z
  • SQLALchemy .query:类“Car”的未解析属性引用“query”

    我有一个这里已经提到的问题https youtrack jetbrains com issue PY 44557 https youtrack jetbrains com issue PY 44557 但我还没有找到解决方案 我使用 Pyt
  • 基于代理的模拟:性能问题:Python vs NetLogo & Repast

    我正在 Python 3 中复制一小段 Sugarscape 代理模拟模型 我发现我的代码的性能比 NetLogo 慢约 3 倍 这可能是我的代码的问题 还是Python的固有限制 显然 这只是代码的一个片段 但 Python 却花费了三分
  • Python pickle:腌制对象不等于源对象

    我认为这是预期的行为 但想检查一下 也许找出原因 因为我所做的研究结果是空白 我有一个函数可以提取数据 创建自定义类的新实例 然后将其附加到列表中 该类仅包含变量 然后 我使用协议 2 作为二进制文件将该列表腌制到文件中 稍后我重新运行脚本
  • AWS EMR Spark Python 日志记录

    我正在 AWS EMR 上运行一个非常简单的 Spark 作业 但似乎无法从我的脚本中获取任何日志输出 我尝试过打印到 stderr from pyspark import SparkContext import sys if name m
  • BeautifulSoup 中的嵌套标签 - Python

    我在网站和 stackoverflow 上查看了许多示例 但找不到解决我的问题的通用解决方案 我正在处理一个非常混乱的网站 我想抓取一些数据 标记看起来像这样 table tbody tr tr tr td td td table tr t
  • 如何在ipywidget按钮中显示全文?

    我正在创建一个ipywidget带有一些文本的按钮 但按钮中未显示全文 我使用的代码如下 import ipywidgets as widgets from IPython display import display button wid
  • 在Python中获取文件描述符的位置

    比如说 我有一个原始数字文件描述符 我需要根据它获取文件中的当前位置 import os psutil some code that works with file lp lib open path to file p psutil Pro
  • 使用 \r 并打印一些文本后如何清除控制台中的一行?

    对于我当前的项目 有一些代码很慢并且我无法使其更快 为了获得一些关于已完成 必须完成多少的反馈 我创建了一个进度片段 您可以在下面看到 当你看到最后一行时 sys stdout write r100 80 n I use 80覆盖最终剩余的
  • 对年龄列进行分组/分类

    我有一个数据框说df有一个柱子 Ages gt gt gt df Age 0 22 1 38 2 26 3 35 4 35 5 1 6 54 我想对这个年龄段进行分组并创建一个像这样的新专栏 If age gt 0 age lt 2 the
  • 为字典中的一个键附加多个值[重复]

    这个问题在这里已经有答案了 我是 python 新手 我有每年的年份和值列表 我想要做的是检查字典中是否已存在该年份 如果存在 则将该值附加到特定键的值列表中 例如 我有一个年份列表 并且每年都有一个值 2010 2 2009 4 1989
  • Conda SafetyError:文件大小不正确

    使用创建 Conda 环境时conda create n env name python 3 6 我收到以下警告 Preparing transaction done Verifying transaction SafetyError Th
  • 使用 Python 绘制 2D 核密度估计

    I would like to plot a 2D kernel density estimation I find the seaborn package very useful here However after searching
  • Scrapy:如何使用元在方法之间传递项目

    我是 scrapy 和 python 的新手 我试图将 parse quotes 中的项目 item author 传递给下一个解析方法 parse bio 我尝试了 request meta 和 response meta 方法 如 sc
  • 如何解释tf.map_fn的结果?

    看代码 import tensorflow as tf import numpy as np elems tf ones 1 2 3 dtype tf int64 alternates tf map fn lambda x x x x el
  • 在 Qt 中自动调整标签文本大小 - 奇怪的行为

    在 Qt 中 我有一个复合小部件 它由排列在 QBoxLayouts 内的多个 QLabels 组成 当小部件调整大小时 我希望标签文本缩放以填充标签区域 并且我已经在 resizeEvent 中实现了文本大小的调整 这可行 但似乎发生了某
  • 使用 Python 的 matplotlib 选择在屏幕上显示哪些图形以及将哪些图形保存到文件中

    我想用Python创建不同的图形matplotlib pyplot 然后 我想将其中一些保存到文件中 而另一些则应使用show 命令 然而 show 显示all创建的数字 我可以通过调用来避免这种情况close 创建我不想在屏幕上显示的绘图
  • Rocket UniData/UniVerse:ODBC 无法分配足够的内存

    每当我尝试使用pyodbc连接到 Rocket UniData UniVerse 数据时我不断遇到错误 pyodbc Error 00000 00000 Rocket U2 U2ODBC 0302810 Unable to allocate
  • 如何将输入读取为数字?

    这个问题的答案是社区努力 help privileges edit community wiki 编辑现有答案以改进这篇文章 目前不接受新的答案或互动 Why are x and y下面的代码中使用字符串而不是整数 注意 在Python 2

随机推荐