要使用 `tf.gather_nd`，您需要把第一个维度的索引与 `query` 组合在一起（即为每个索引补上它所在行的行号）。下面是一种方法：
import tensorflow as tf
import numpy as np
np.random.seed(100)

# Reorder each row of `params` along axis 1 so that feature 0 is in
# descending order, using top_k for the ordering and gather_nd to apply it.
# NOTE: this uses the TF1 graph API (tf.placeholder / tf.Session); under TF2
# the equivalents are tf.compat.v1.placeholder / tf.compat.v1.Session.
with tf.Graph().as_default(), tf.Session() as sess:
    # Shape [None, 368, 5]: only the first dimension is left dynamic.
    params = tf.placeholder(tf.float32, [None, 368, 5])
    # Static size of axis 1, used both as k for top_k and for tiling.
    num_items = params.shape[1]
    # top_k over all items with sorted=True gives, for every row, the axis-1
    # indices ordered by descending params[:, :, 0].
    query = tf.nn.top_k(params[:, :, 0], k=num_items, sorted=True).indices
    n = tf.shape(params)[0]
    # gather_nd needs full (row, item) coordinates, so build a tensor holding
    # the row index of every entry in `query`: ii[b, j] == b.
    ii = tf.tile(tf.range(n)[:, tf.newaxis], (1, num_items))
    # Stack indices: idx has shape (n, num_items, 2), idx[b, j] == [b, query[b, j]].
    idx = tf.stack([ii, query], axis=-1)
    # Gather reordered tensor: result[b, j] == params[b, query[b, j]].
    result = tf.gather_nd(params, idx)
    # Test
    out = sess.run(result, feed_dict={params: np.random.rand(10, 368, 5)})
    # Check the order is correct: feature 0 must be non-increasing along axis 1.
    print(np.all(np.diff(out[:, :, 0], axis=1) <= 0))
    # True