【PyTorch】nn.TransformerEncoder 使用 src_key_padding_mask 时出现nan

2023-05-16

问题描述:

        在使用nn.TransformerEncoder时,不使用src_key_padding_mask,编码的输出正常,使用src_key_padding_mask后编码结果变成nan了。

ego_transformer_encoder = nn.TransformerEncoder(ego_encoder_layer, num_layers=6)
ego_transformer_features = ego_transformer_encoder(ego_seq2, src_key_padding_mask=src_padding_mask)

分析解决:

        出现nan的原因来自于src_key_padding_mask,src_key_padding_mask 是一个二值化的tensor,在需要被忽略地方应该是True,在需要保留原值的情况下,是False。检查发现src_key_padding_mask全为True,此时会导致编码后结果全为nan。

        解决方法是更新mask或不使用mask。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【PyTorch】nn.TransformerEncoder 使用 src_key_padding_mask 时出现nan 的相关文章

随机推荐