pytorch / numpy 中具有任意和可变维数的部分切片

2024-01-02

给定 numpy(或 pytorch)中的二维张量,我可以同时沿所有维度进行部分切片,如下所示:

>>> import numpy as np
>>> a = np.arange(2*3).reshape(2,3)
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
>>> a[1:,1:]
array([[ 5,  6,  7],
       [ 9, 10, 11]])

如果我在实现时不知道维数,如何实现相同的切片模式,无论张量中的维数如何? (即我想要a[1:] if a只有一维,a[1:,1:]对于二维,a[1:,1:,1:]对于三个维度,依此类推)

如果我可以用如下所示的一行代码来完成它,那就太好了,但这是无效的:

a[(1:,) * len(a.shape)]  # SyntaxError: invalid syntax

我对适用于 pytorch 张量的解决方案特别感兴趣(只需将上面的 numpy 替换为 torch,示例是相同的),但我认为如果该解决方案同时适用于 numpy 和 pytorch,那么它可能也是最好的。


Answer:制作一个元组slice https://docs.python.org/3/library/functions.html#slice对象可以解决这个问题:

a[(slice(1,None),) * len(a.shape)]

解释: slice是一个内置的 python 类(不依赖于 numpy 或 pytorch),它提供了用于描述切片的下标表示法的替代方法。答案 https://stackoverflow.com/a/12616901/3780389 to 另一个问题 https://stackoverflow.com/q/12616821/3780389建议使用它作为在 python 变量中存储切片信息的方式。这蟒蛇术语表 https://docs.python.org/3/glossary.html#term-slice指出

括号(下标)表示法使用slice https://docs.python.org/3/library/functions.html#slice内部对象。

自从__getitem__方法用于numpy ndarrays https://docs.scipy.org/doc/numpy/user/basics.indexing.html and 火炬张量 https://pytorch.org/docs/stable/tensors.html#torch.Tensor支持切片的多维索引,它们也必须支持切片对象的多维索引,因此我们可以将这些切片创建一个具有正确长度的元组。

顺便说一句,您可以通过创建一个虚拟类来了解 python 如何使用切片对象,如下所示,然后对其进行切片:

class A(object):
    def __getitem__(self, ix):
        return ix

print(A()[5])  # 5
print(A()[1:])  # slice(1, None, None)
print(A()[1:,1:])  # (slice(1, None, None), slice(1, None, None))
print(A()[1:,slice(1,None)])  #  (slice(1, None, None), slice(1, None, None))


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch / numpy 中具有任意和可变维数的部分切片 的相关文章

随机推荐