假设我有这两个张量:
-
valueMatrix
,形状为(?, 3)
, where ?
是批量大小
-
indexMatrix
,形状为(?, 1)
我想从中检索值valueMatrix
在包含的索引处indexMatrix
.
示例(伪代码):
valueMatrix = [[7,15,5],[4,6,8]] -- shape=(2,3) -- type=float
indexMatrix = [[1],[0]] -- shape = (2,1) -- type=int
我想从这个例子中做一些类似的事情:
valueMatrix[indexMatrix] --> returns --> [[15],[4]]
与其他后端相比,我更喜欢 Tensorflow,但答案必须与使用 Lambda 层或其他适合任务的层的 Keras 模型兼容。
import tensorflow as tf
valueMatrix = tf.constant([[7,15,5],[4,6,8]])
indexMatrix = tf.constant([[1],[0]])
# create the row index with tf.range
row_idx = tf.reshape(tf.range(indexMatrix.shape[0]), (-1,1))
# stack with column index
idx = tf.stack([row_idx, indexMatrix], axis=-1)
# extract the elements with gather_nd
values = tf.gather_nd(valueMatrix, idx)
with tf.Session() as sess:
print(sess.run(values))
#[[15]
# [ 4]]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)