在您的 `scaled_dot_product` 实现中,您使用 `query` 的维度进行缩放,但根据原始论文,应当使用 `key` 的维度 $d_k$ 进行归一化(即将 $QK^\top$ 除以 $\sqrt{d_k}$)。除此之外,这个实现看起来不错,只是不够通用。
class MultiAttention(tf.keras.layers.Layer):
    """Multi-head self-attention with an independent set of projections per head.

    Every head projects the input into its own query, key and value spaces
    with separate Dense layers and applies scaled dot-product attention. The
    per-head results are then concatenated and projected to ``out_dim`` by
    ``self.wo``.

    Args:
        num_of_heads: Number of attention heads; must be >= 1.
        out_dim: Width of the layer output. Each head uses
            ``out_dim // num_of_heads`` units. When ``out_dim`` is not a
            multiple of ``num_of_heads``, the concatenated heads are narrower
            than ``out_dim`` and ``self.wo`` projects them back up.

    Raises:
        ValueError: if ``num_of_heads`` < 1, or ``out_dim`` < ``num_of_heads``
            (which would give every head zero units).
    """

    def __init__(self, num_of_heads, out_dim):
        if num_of_heads < 1:
            raise ValueError(f"num_of_heads must be >= 1, got {num_of_heads}")
        if out_dim < num_of_heads:
            raise ValueError(
                f"out_dim ({out_dim}) must be >= num_of_heads ({num_of_heads}) "
                "so that every head gets at least one unit"
            )
        super(MultiAttention, self).__init__()
        self.out_dim = out_dim
        self.num_of_heads = num_of_heads
        self.depth = self.out_dim // self.num_of_heads
        # One independent query/key/value projection per head.
        self.wq = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for _ in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(self.out_dim)

    def call(self, x, return_attention_scores=False):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Input tensor; the same tensor supplies queries, keys and values.
                Assumed to be (batch, seq_len, features) — TODO confirm for
                other ranks.
            return_attention_scores: If True, also return the per-head
                attention weights, stacked on axis 1 (right after the batch
                axis). This mirrors the keyword of the same name on
                ``tf.keras.layers.MultiHeadAttention``.

        Returns:
            The attention output after the final ``out_dim`` projection, or
            ``(output, attention_weights)`` when ``return_attention_scores``
            is True.
        """
        head_outputs = []
        head_weights = []
        for wq, wk, wv in zip(self.wq, self.wk, self.wv):
            z, weights = self._scaled_dot_product_with_weights(wq(x), wk(x), wv(x))
            head_outputs.append(z)
            head_weights.append(weights)
        multi_head = tf.concat(head_outputs, axis=-1)
        multi_head_attention = self.wo(multi_head)
        if return_attention_scores:
            return multi_head_attention, tf.stack(head_weights, axis=1)
        return multi_head_attention

    def scaled_dot_product(self, q, k, v):
        """Return softmax(q k^T / sqrt(d_k)) v for a single head."""
        z, _ = self._scaled_dot_product_with_weights(q, k, v)
        return z

    def _scaled_dot_product_with_weights(self, q, k, v):
        """Scaled dot-product attention returning ``(output, attention_weights)``.

        The logits are divided by sqrt(d_k), where d_k is the size of the
        key's last axis, as in "Attention Is All You Need".
        """
        qkt = tf.matmul(q, k, transpose_b=True)
        # Cast d_k to the logits' dtype so the division does not raise a dtype
        # mismatch when the logits are not float32.
        dk = tf.math.sqrt(tf.cast(tf.shape(k)[-1], dtype=qkt.dtype))
        weights = tf.nn.softmax(qkt / dk, axis=-1)
        return tf.matmul(weights, v), weights
# Smoke test for MultiAttention. The input is assumed to be
# (batch, seq_len, features) — here 2 x 2 x 32.
multi = MultiAttention(num_of_heads=3, out_dim=32)
sample_ip = tf.random.normal(shape=(2, 2, 32))
print(sample_ip.shape)
# A bare expression is only echoed in a notebook/REPL; print it so a plain
# script run shows the result too.
print(multi(sample_ip).shape)
一般的 Transformer 注意力结构如下图所示:前两个线性层分别对应 `query` 和 `key`,负责生成注意力权重图;随后以矩阵乘法的方式,用这些权重对 `value` 进行加权。
图片来源 https://www.youtube.com/watch?v=mMa2PmYJlCo.
我知道您是在尝试精简 TF 官方教程代码(https://www.tensorflow.org/tutorials/text/transformer),但我认为您应该先在原问题中注明这一出处。在原始实现中,除了加权后的特征图之外,还会一并返回注意力权重(即概率分数)。我认为不应省略这一点。
您所参考的原始代码(https://www.tensorflow.org/tutorials/text/transformer)更加通用,也做了更高效的优化。
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention in the style of the TensorFlow Transformer tutorial.

    Queries, keys and values are each projected to ``d_model`` by one Dense
    layer and split into ``num_heads`` heads of ``d_model // num_heads``
    features. Each head applies scaled dot-product attention independently.
    The heads are then merged back and projected to ``d_model`` by
    ``self.dense``.

    Args:
        d_model: Total projection width; must be positive and divisible by
            ``num_heads``.
        num_heads: Number of attention heads; must be >= 1.

    Raises:
        ValueError: if the constraints above are violated.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Raise rather than `assert`: assertions are stripped under `python -O`,
        # which would let an invalid head split reach `tf.reshape` at call time.
        if num_heads < 1 or d_model < 1 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Compute softmax(q k^T / sqrt(d_k) + mask * -1e9) v.

        Args:
            q, k, v: Tensors whose leading dimensions agree for ``tf.matmul``;
                ``k`` and ``v`` share the key sequence axis.
            mask: Optional tensor that must broadcast to the logits shape
                (..., seq_len_q, seq_len_k). Positions where it is 1 are
                pushed to ~-1e9 before the softmax, i.e. hidden.

        Returns:
            ``(output, attention_weights)`` with shapes
            (..., seq_len_q, depth_v) and (..., seq_len_q, seq_len_k).
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

        # Scale by sqrt(d_k), computed in the logits' own dtype so the division
        # does not raise a dtype mismatch for non-float32 logits.
        dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        if mask is not None:
            # Cast so bool/int masks (or a different float dtype) can be added.
            # NOTE(review): -1e9 overflows to -inf in float16; use a finite
            # large negative if float16 logits must be supported.
            scaled_attention_logits += tf.cast(mask, scaled_attention_logits.dtype) * -1e9

        # Softmax over the last axis (seq_len_k) so each query's weights sum to 1.
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
        return output, attention_weights

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).

        Transpose the result so that the shape is
        (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        """Attend from ``q`` over ``k``/``v``; note the positional order (v, k, q).

        Args:
            v, k, q: Value, key and query tensors, assumed to be
                (batch_size, seq_len, features) — TODO confirm for other ranks.
            mask: Optional mask, see ``scaled_dot_product_attention``.

        Returns:
            ``(output, attention_weights)`` with shapes
            (batch_size, seq_len_q, d_model) and
            (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)

        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
顺便一提,TF 2.4 已正式加入 `tf.keras.layers.MultiHeadAttention`(https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention)层。
# Built-in multi-head attention layer (TF 2.4+), applied to a symbolic Keras Input.
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)

input_tensor = tf.keras.Input(shape=(2, 2, 32))
print(input_tensor.shape)

# Self-attention: the same tensor is passed as both query and value.
symbolic_out = layer(input_tensor, input_tensor)
print(symbolic_out.shape)
您可以按如下方式分别测试这两种实现:
# Custom layer MHA: d_model=512 split across 2 heads of depth 256.
multi = MultiHeadAttention(d_model=512, num_heads=2)
y = tf.random.uniform((1, 60, 512))
# The layer's positional order is (v, k, q); self-attention passes y for all three.
out, attn = multi(y, k=y, q=y, mask=None)
# A bare expression is discarded when run as a script, so print it.
print((out.shape, attn.shape))
# -> (TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))
# Built-in layer, configured like the custom one above: 2 heads of size
# 512 // 2 = 256. With key_dim=2 the output shapes would still match, but the
# projections would be far narrower, so the two layers would not be comparable.
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=256)
y = tf.random.uniform((1, 60, 512))
# Self-attention: query and value are both y; the key defaults to the value.
out, attn = layer(y, y, return_attention_scores=True)
# A bare expression is discarded when run as a script, so print it.
print((out.shape, attn.shape))
# -> (TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))