我在文章的第三部分分析了这种行为“分析 tf.function 以发现 Autograph 的优势和微妙之处”(我强烈建议阅读所有 3 部分,以了解如何在用tf.function
- 答案底部的链接)。
For the __eq__
and tf.equal
问题,答案是:
简而言之:__eq__
运算符(对于tf.Tensor
) 已被覆盖,但运算符不使用tf.equal
为了检查 Tensor 相等性,它只检查 Python 变量标识(如果您熟悉 Java 编程语言,这与字符串对象上使用的 == 运算符完全相同)。原因是,tf.Tensor
对象需要是可哈希的,因为它在 Tensorflow 代码库中的任何地方都用作 dict 对象的键。
而对于所有其他运算符,答案是 AutoGraph 不会将 Python 运算符转换为 TensorFlow 逻辑运算符。在本节中AutoGraph 如何(不)转换运算符我展示了每个 Python 运算符都会转换为始终被评估为 false 的图形表示形式。
事实上,以下示例生成输出“wat”
@tf.function
def if_elif(a, b):
if a > b:
tf.print("a > b", a, b)
elif a == b:
tf.print("a == b", a, b)
elif a < b:
tf.print("a < b", a, b)
else:
tf.print("wat")
x = tf.constant(1)
if_elif(x,x)
在实践中,AutoGraph无法将Python代码转换为图形代码;我们必须仅使用 TensorFlow 原语来帮助它。在这种情况下,您的代码将按您的预期工作。
@tf.function
def if_elif(a, b):
if tf.math.greater(a, b):
tf.print("a > b", a, b)
elif tf.math.equal(a, b):
tf.print("a == b", a, b)
elif tf.math.less(a, b):
tf.print("a < b", a, b)
else:
tf.print("wat")
我在这里提供了所有三篇文章的链接,我想您会发现它们很有用:
part 1, part 2, part 3