torch.unique() 中的参数“dim”如何工作?

2024-03-13

我试图提取矩阵每一行中的唯一值并将它们返回到同一个矩阵中(重复值设置为 0)例如,我想转换

torch.Tensor(([1, 2, 3, 4, 3, 3, 4],
              [1, 6, 3, 5, 3, 5, 4]])

to

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 0, 0, 4]])

or

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 4, 0, 0]])

IE。行中的顺序并不重要。我尝试过使用pytorch.unique()并且在文档中提到可以使用参数指定采用唯一值的维度dim。然而,它似乎不适用于这种情况。

我试过了:

output= torch.unique(torch.Tensor([[4,2,52,2,2],[5,2,6,6,5]]), dim = 1)

output

这使

tensor([[ 2.,  2.,  2.,  4., 52.],
        [ 2.,  5.,  6.,  5.,  6.]])

有人对此有特别的解决办法吗?如果可能的话,我会尽量避免 for 循环。


人们必须承认unique如果没有给出适当的示例和解释,函数有时可能会非常混乱。

The dim参数指定要应用到矩阵张量的哪个维度。

例如,在二维矩阵中,dim=0将使操作垂直执行,其中dim=1意思是水平的。

例如,让我们考虑一个 4x4 矩阵dim=1。正如你从我下面的代码中看到的,unique操作是逐行应用的。

您注意到该数字两次出现11在第一行和最后一行。 Numpy 和 Torch 这样做是为了保留最终矩阵的形状。

但是,如果您没有指定任何维度,Torch 会自动展平您的矩阵,然后应用unique到它,你将得到一个包含唯一数据的一维数组。

import torch

m = torch.Tensor([
    [11, 11, 12,11], 
    [13, 11, 12,11], 
    [16, 11, 12, 11],  
    [11, 11, 12, 11]
])

output, indices = torch.unique(m, sorted=True, return_inverse=True, dim=1)
print("Ori \n{}".format(m.numpy()))
print("Sorted \n{}".format(output.numpy()))
print("Indices \n{}".format(indices.numpy()))

# without specifying dimension
output, indices = torch.unique(m, sorted=True, return_inverse=True)
print("Sorted (no dim) \n{}".format(output.numpy()))

结果(暗淡=1)

Ori
[[11. 11. 12. 11.]
 [13. 11. 12. 11.]
 [16. 11. 12. 11.]
 [11. 11. 12. 11.]]
Sorted
[[11. 11. 12.]
 [11. 13. 12.]
 [11. 16. 12.]
 [11. 11. 12.]]
Indices
[1 0 2 0]

结果(无维度)

Sorted (no dim)
[11. 12. 13. 16.]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

torch.unique() 中的参数“dim”如何工作? 的相关文章

随机推荐