我有三个张量,A, B and C
在张量流中,A
and B
都是形状(m, n, r)
, C
是形状的二元张量(m, n, 1)
.
我想根据以下值从 A 或 B 中选择元素C
。显而易见的工具是tf.select
,但是它没有广播语义,所以我需要首先显式广播C
与 A 和 B 的形状相同。
这将是我第一次尝试如何做到这一点,但它不喜欢我混合张量(tf.shape(A)[2]
) 进入形状列表。
import tensorflow as tf
A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])
C = tf.greater_equal(C, tf.zeros_like(C))
C = tf.tile(C, [1,1,tf.shape(A)[2]])
D = tf.select(C, A, B)
这里正确的做法是什么?