看一眼torch.nonzero
这大致相当于np.where
。它将二进制掩码转换为索引:
>>> X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
>>> mask = X >= 0
>>> mask
tensor([1, 1, 0, 1, 1, 1], dtype=torch.uint8)
>>> indices = torch.nonzero(mask)
>>> indices
tensor([[0],
[1],
[3],
[4],
[5]])
>>> X[indices]
tensor([[0.1000],
[0.5000],
[0.0000],
[1.2000],
[0.0000]])
解决方案是这样写:
mask = X >= 0
new_tensor = X[mask]
indices = torch.nonzero(mask)