torch.stack() 和 torch.cat() 函数有什么区别?

2024-04-30

OpenAI 的强化学习 REINFORCE 和 actor-critic 示例具有以下代码:

加强 https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py:

policy_loss = torch.cat(policy_loss).sum()

演员评论家 https://github.com/pytorch/examples/blob/master/reinforcement_learning/actor_critic.py:

loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

一种正在使用torch.cat https://pytorch.org/docs/stable/generated/torch.cat.html,其他用途torch.stack https://pytorch.org/docs/stable/generated/torch.stack.html,对于类似的用例。

据我的理解,该文档没有给出它们之间的任何明确区别。

我很高兴知道这些功能之间的差异。


stack https://pytorch.org/docs/stable/generated/torch.stack.html

沿 a 连接张量序列新维度.

cat https://pytorch.org/docs/stable/generated/torch.cat.html

连接给定的 seq 张量序列在给定维度.

So if A and B形状为 (3, 4):

  • torch.cat([A, B], dim=0)形状为 (6, 4)
  • torch.stack([A, B], dim=0)形状为 (2, 3, 4)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

torch.stack() 和 torch.cat() 函数有什么区别? 的相关文章

随机推荐