I stumbled upon the fact that, besides simple indexing and slicing like NumPy arrays, PyTorch tensors also support a form of fancy indexing: indexing one tensor with another tensor, which expands the indexed tensor along the index tensor's shape. Below is an example, followed by an equivalent NumPy version.
Example
in_a.shape = [b, c, h, w]
in_b.shape = [m, n]
Indexing in_a with in_b:
out = in_a[:, :, in_b, :]
The resulting shape of out is:
out.shape = [b, c, m, n, w]
For example:
>>> in_a = torch.randn(1,1,4,5)
>>> in_b = torch.tensor([[2,0],[1,3],[2,3]])
>>> in_a
tensor([[[[ 0.2668,  0.5453,  0.5563,  0.7396, -1.1646],
          [-0.1059,  0.8955,  0.8947, -3.0298, -2.0912],
          [ 0.8145,  0.3670,  0.4827,  0.1327, -0.9437],
          [ 1.3698, -0.8281, -0.8810,  1.6670, -1.8736]]]])
>>> in_b.shape
torch.Size([3, 2])
>>> in_a[:,:,in_b,:].shape
torch.Size([1, 1, 3, 2, 5])
>>> in_a[:,:,in_b,:]
tensor([[[[[ 0.8145,  0.3670,  0.4827,  0.1327, -0.9437],
           [ 0.2668,  0.5453,  0.5563,  0.7396, -1.1646]],
          [[-0.1059,  0.8955,  0.8947, -3.0298, -2.0912],
           [ 1.3698, -0.8281, -0.8810,  1.6670, -1.8736]],
          [[ 0.8145,  0.3670,  0.4827,  0.1327, -0.9437],
           [ 1.3698, -0.8281, -0.8810,  1.6670, -1.8736]]]]])
In other words, this indexes in_a along dim=2, taking the rows at index = [2,0], [1,3], [2,3] in turn. Note in particular that the index values must not exceed the size of dim=2: in this example in_a has shape [1,1,4,5], so when indexing along dim=2 the only valid index values are 0, 1, 2, 3.
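As a quick check of that range rule, an out-of-bounds index raises an IndexError. A minimal sketch (assuming torch is installed):

```python
import torch

in_a = torch.randn(1, 1, 4, 5)

# dim=2 has size 4, so valid indices are 0..3; 4 is out of range
bad_idx = torch.tensor([[4, 0]])
try:
    in_a[:, :, bad_idx, :]
except IndexError as e:
    print("caught:", e)
```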
Another example:
>>> in_a[:,:,:,in_b].shape
torch.Size([1, 1, 4, 3, 2])
>>> in_a[:,:,:,in_b]
tensor([[[[[ 0.5563,  0.2668],
           [ 0.5453,  0.7396],
           [ 0.5563,  0.7396]],
          [[ 0.8947, -0.1059],
           [ 0.8955, -3.0298],
           [ 0.8947, -3.0298]],
          [[ 0.4827,  0.8145],
           [ 0.3670,  0.1327],
           [ 0.4827,  0.1327]],
          [[-0.8810,  1.3698],
           [-0.8281,  1.6670],
           [-0.8810,  1.6670]]]]])
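The pattern in both examples can be checked programmatically: each slice out[:, :, i, j, :] is exactly in_a[:, :, in_b[i, j], :]. A minimal sketch verifying this:

```python
import torch

in_a = torch.randn(1, 1, 4, 5)
in_b = torch.tensor([[2, 0], [1, 3], [2, 3]])

out = in_a[:, :, in_b, :]
assert out.shape == torch.Size([1, 1, 3, 2, 5])

# every (i, j) entry of in_b selects the corresponding row along dim=2
for i in range(in_b.shape[0]):
    for j in range(in_b.shape[1]):
        assert torch.equal(out[:, :, i, j, :], in_a[:, :, in_b[i, j], :])
```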
Fancy indexing with NumPy
So far, the only approach I have come up with is a rather clumsy loop that reads and assigns element by element:
import numpy as np

num_a = in_a.numpy()
num_b = in_b.numpy()
[b, c, h, w] = num_a.shape
[m, n] = num_b.shape
out_ny = np.zeros([b, c, m, n, w])
for i in range(m):
    for j in range(n):
        out_ny[:, :, i, j, :] = num_a[:, :, num_b[i, j], :]
>>> out_ny
array([[[[[ 0.81448293,  0.36703789,  0.48273084,  0.13274327, -0.94368148],
          [ 0.26677063,  0.54529017,  0.55633378,  0.73956281, -1.16463828]],
         [[-0.10586801,  0.89547068,  0.89467597, -3.02978396, -2.09123206],
          [ 1.36978781, -0.8280825 , -0.8810119 ,  1.6670413 , -1.87361884]],
         [[ 0.81448293,  0.36703789,  0.48273084,  0.13274327, -0.94368148],
          [ 1.36978781, -0.8280825 , -0.8810119 ,  1.6670413 , -1.87361884]]]]])
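As it turns out, the loop is not actually needed: NumPy supports the same tensor-as-index fancy indexing directly, so the PyTorch expression carries over verbatim. A sketch comparing the two:

```python
import numpy as np

num_a = np.random.randn(1, 1, 4, 5)
num_b = np.array([[2, 0], [1, 3], [2, 3]])

# direct fancy indexing, same syntax as the PyTorch version
out_direct = num_a[:, :, num_b, :]
print(out_direct.shape)  # (1, 1, 3, 2, 5)

# matches the explicit double loop above
b, c, h, w = num_a.shape
m, n = num_b.shape
out_loop = np.zeros([b, c, m, n, w])
for i in range(m):
    for j in range(n):
        out_loop[:, :, i, j, :] = num_a[:, :, num_b[i, j], :]
assert np.array_equal(out_direct, out_loop)
```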