也许有点晚了,但对于任何最终在这篇文章中寻找解决方案的人来说,这可能会有所帮助。
使用 Transformer 的典型场景是在 NLP 问题中,其中有成批的句子(为了简单起见,我们假设它们已经被标记化)。考虑以下示例:
sentences = [['Lorem', 'ipsum', 'dolor', 'sit', 'amet'], ['Integer', 'tincidunt', 'in', 'arcu', 'nec', 'fringilla', 'suscipit']]
如您所见,我们有两个长度不同的句子。为了在张量流模型中学习它们,我们可以用一个特殊的标记填充最短的一个,比方说'[PAD]'
,然后按照您的建议将它们输入到 Transformer 模型中。因此:
sentences = tf.constant([['Lorem', 'ipsum', 'dolor', 'sit', 'amet', '[PAD]', '[PAD]'], ['Integer', 'tincidunt', 'in', 'arcu', 'nec', 'fringilla', 'suscipit']])
还假设我们已经有了从某些语料库中提取的标记词汇表,例如词汇表1000
令牌,我们可以定义一个StringLookup
将我们的一批句子转换为给定词汇的数字投影的层。我们可以指定使用哪个令牌masking.
lookup = tf.keras.layers.StringLookup(vocabulary=vocabulary, mask_token='[PAD]')
x = lookup(sentences)
# x is a tf.Tensor([[2, 150, 19, 997, 9, 0, 0], [72, 14, 1, 1, 960, 58, 87]], shape=(2, 7), dtype=int64)
我们可以看到[PAD]
令牌映射到0词汇中的价值。
典型的下一步是将这个张量输入到Embedding
层,像这样:
embedding = tf.keras.layers.Embedding(input_dim=lookup.vocabulary_size(), output_dim=64, mask_zero=True)
这里的关键是论证mask_zero
。根据文档 https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding,这个论证的意思是:
布尔值,输入值 0 是否是一个特殊的“填充”值,应该被屏蔽掉......
这允许embedding
层为后续层生成一个掩码,以指示哪些位置应该出现,哪些位置不应该出现。该掩码可以通过以下方式访问:
mask = embedding.compute_mask(sentences)
# mask is a tf.Tensor([[True, True, True, True, True, False, False], [True, True, True, True, True, True, True]], shape=(2, 7), dtype=bool)
嵌入的张量的形式为:
y = embedding(sentences)
# y is a tf.Tensor of shape=(2, 7, 64), dtype=float32)
为了使用mask
进入MultiHeadAttention
层,必须重新调整掩模的形状才能满足形状要求,根据文档是[B, T, S]
where B
意味着批量大小(示例中为 2),T
意味着查询大小(在我们的示例中为 7),以及S
意味着key size(如果我们使用自注意力,则再次为 7)。同样在多头注意力层中,我们必须注意头的数量H
。使用此输入创建兼容掩码的最简单方法是通过广播:
mask = mask[:, tf.newaxis, tf.newaxis, :]
# mask is a tf.Tensor of shape=(2, 1, 1, 7), dtype=bool) -> [B, H, T, S]
然后我们终于可以喂食了MultiHeadAttention
层如下:
mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=64)
z = mha(y, y, attention_mask=mask)
所以为了使用,你的TransformerBlock
带有遮罩的图层,您应该添加到call
方法一mask
论证,如下:
def call(self, inputs, training, mask=None):
attn_output = self.att(inputs, inputs, attention_mask=mask)
...
在您调用的层/模型中MultiHeadAttention
层,您必须传递/传播使用生成的掩码Embedding
layer.