如果将两个维度为 n * c * h * w 进行拼接,如果指定dim为以下值
- dim = 0, 拼接后维度为 2n * c * h * w
- dim = 1, 拼接后维度为 n * 2c * h * w
- dim = 2,拼接后维度为 n * c * 2h * w
- dim = 3,拼接后维度为 n * c * h * 2w
即 dim = i 就表示在第 i 维度度进行拼接,此时除第 i 维度数可以不同外, 其他维度必须相同, 否则无法拼接。
测试1:
x1 = torch.rand((1, 16, 32, 32))
y1 = torch.rand((1, 32, 32, 32))
- 在 dim = 0 拼接
out0 = torch.cat((x1, y1), dim = 0)
报错:RuntimeError: Sizes of tensors must match except in dimension 0. Got 16 and 32 in dimension 1 (The offending index is 1)
即 x1, x2 在其他维度不相等(x1(16, 32, 32),x2(32, 32, 32))
2. 在 dim = 1 拼接
out1 = torch.cat((x1, y1), dim = 1)
print(out1.size())
输出: torch.Size([1, 48, 32, 32]),即在dim = 1 上拼接后为 16 + 32 = 48
同理,在dim = 2 或者 dim = 3 维度拼接都会出错
测试1:测试两个维度一模一样的张量
in_put1 = torch.rand((1, 64, 8, 8))
in_put2 = torch.rand((1, 64, 8, 8))
out0 = torch.cat((in_put1, in_put2), dim = 0)
out1 = torch.cat((in_put1, in_put2), dim = 1)
out2 = torch.cat((in_put1, in_put2), dim = 2)
out3 = torch.cat((in_put1, in_put2), dim = 3)
我们的预期结果当然是
out0: (2, 64, 8, 8)
out1: (1, 128, 8, 8)
out2: (1, 64, 16, 6)
out2: (1, 64, 8, 16)
查看打印结果的确如此: