矢量化解决方案rand+argsort
trick
我们可以沿着指定的轴生成唯一的索引,并使用以下命令索引到输入数组中advanced-indexing
。为了生成唯一索引,我们将使用random float generation + sort trick https://stackoverflow.com/a/45438143/,从而给我们一个矢量化的解决方案。我们还将对其进行概括以涵盖通用的n-dim
数组和泛型axes
with np.take_along_axis https://docs.scipy.org/doc/numpy-1.15.1/reference/generated/numpy.take_along_axis.html。最终的实现看起来像这样 -
def shuffle_along_axis(a, axis):
idx = np.random.rand(*a.shape).argsort(axis=axis)
return np.take_along_axis(a,idx,axis=axis)
请注意,此随机播放不会就地进行,并返回一个随机播放的副本。
样本运行 -
In [33]: a
Out[33]:
array([[18, 95, 45, 33],
[40, 78, 31, 52],
[75, 49, 42, 94]])
In [34]: shuffle_along_axis(a, axis=0)
Out[34]:
array([[75, 78, 42, 94],
[40, 49, 45, 52],
[18, 95, 31, 33]])
In [35]: shuffle_along_axis(a, axis=1)
Out[35]:
array([[45, 18, 33, 95],
[31, 78, 52, 40],
[42, 75, 94, 49]])