【Pytorch】tensor索引另一个tensor(tensor[tensor])

2023-05-16

说明

最近在使用pytorch中tensor的时候,遇到了需要使用tensor1[tensor2]的情况,发现了这篇博客,他从代码的角度解释了其工作原理,这里我用图示的方式解释tensor类型为longTensor时的情况

工作原理

  1. 首先,创建两个tensor,如下:
    import numpy as np
    import torch
    x1 = np.array([[1,0,0,1],[1,1,0,1],[0,1,1,0],[0,1,1,1],[1,1,0,0]])
    x2 = np.array([[1,0,1],[1,1,1]])
    
    x1= torch.from_numpy(x1).long()
    x2= torch.from_numpy(x2).long()
    
    """
    x1:tensor([[1, 0, 0, 1],
            [1, 1, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 1, 1],
            [1, 1, 0, 0]])  
    x1.shape:torch.Size([5, 4])
    x2:tensor([[1, 0, 1],
            [1, 1, 1]])
    x2.shape:torch.Size([2, 3])
    x1[x2]:tensor([[[1, 1, 0, 1],
    	         [1, 0, 0, 1],
    	         [1, 1, 0, 1]],
    	
    	        [[1, 1, 0, 1],
    	         [1, 1, 0, 1],
    	         [1, 1, 0, 1]]])
    """
    
  2. 说明一点就是x1[x2].shap不一定等于x2[x1].shape,那么我如何在不计算的前提下知道LongTensor1[LongTensor2].shape呢,以二维LongTensor为例,可推出公式如下:LongTensor1[LongTensor2].shape=torch.Size([LongTensor2.shape[0],LongTensor2.shape[1],LongTensor1.shape[1]),即上述
    • x1[x2].shape=torch.Size([2, 3, 4])
    • x2[x1].shape=torch.Size([5, 4, 3])
  3. 动态图说明
    动态图演示
    可以看到由于x2只包含0和1,所以在取索引的时候,只会索引x1的前两维
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【Pytorch】tensor索引另一个tensor(tensor[tensor]) 的相关文章

随机推荐