Verifying a Multi-Head Attention implementation in a Transformer

2024-05-20

I have implemented a multi-head attention layer for a Transformer. There are so many implementations floating around that it gets confusing. Could someone verify whether my implementation is correct?

DotProductAttention is adapted from: https://www.tensorflow.org/tutorials/text/transformer#setup

import tensorflow as tf

def scaled_dot_product(q, k, v):
    # calculates Q . K^T
    qkt = tf.matmul(q, k, transpose_b=True)
    # calculates the scaling factor sqrt(d_k)
    dk = tf.math.sqrt(tf.cast(q.shape[-1], dtype=tf.float32))
    scaled_qkt = qkt / dk
    softmax = tf.nn.softmax(scaled_qkt, axis=-1)

    z = tf.matmul(softmax, v)
    # shape: (m, Tx, depth), same shape as q, k, v
    return z

class MultiAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_of_heads):
        super(MultiAttention, self).__init__()
        self.d_model = d_model
        self.num_of_heads = num_of_heads
        self.depth = d_model // num_of_heads
        # one projection of size depth per head for query, key and value
        self.wq = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        # output projection back to d_model
        self.wo = tf.keras.layers.Dense(d_model)

    def call(self, x):
        
        multi_attn = []
        for i in range(self.num_of_heads):
            Q = self.wq[i](x)
            K = self.wk[i](x)
            V = self.wv[i](x)
            multi_attn.append(scaled_dot_product(Q,K,V))
            
        multi_head = tf.concat(multi_attn,axis=-1)
        multi_head_attention = self.wo(multi_head)
        return multi_head_attention

# Calling the attention layer
multi = MultiAttention(d_model=512, num_of_heads=8)
m = 5; sequence_length = 4; word_embedding_dim = 512
sample_ip = tf.constant(tf.random.normal(shape=(m, sequence_length, word_embedding_dim)))
attn = multi(sample_ip)
# shape of output (attn): (5, 4, 512)

In your scaled_dot_product you take the scaling factor from the query, but according to the original paper the normalization uses the key dimension, i.e. dividing by sqrt(d_k). Since q and k share the same last dimension here the result is numerically identical, but using k matches the paper. Apart from that, the implementation looks fine, though it is not very generic.

class MultiAttention(tf.keras.layers.Layer):
    def __init__(self, num_of_heads, out_dim):
        super(MultiAttention,self).__init__()
        self.out_dim      = out_dim
        self.num_of_heads = num_of_heads
        self.depth        = self.out_dim // self.num_of_heads
        self.wq = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(self.out_dim)
        
    def call(self,x):
        multi_attn = []
        for i in range(self.num_of_heads):
            Q = self.wq[i](x)
            K = self.wk[i](x)
            V = self.wv[i](x)
            multi_attn.append(self.scaled_dot_product(Q,K,V))

        multi_head = tf.concat(multi_attn, axis=-1)
        multi_head_attention = self.wo(multi_head)
        return multi_head_attention

    def scaled_dot_product(self, q, k, v):
        qkt = tf.matmul(q, k, transpose_b=True)
        dk = tf.math.sqrt( tf.cast(k.shape[-1], dtype=tf.float32) )
        scaled_qkt = qkt/dk
        softmax = tf.nn.softmax(scaled_qkt, axis=-1)
        z = tf.matmul(softmax, v)
        return z

multi = MultiAttention(num_of_heads=3, out_dim=32)
sample_ip = tf.random.normal(shape=(2, 2, 32)); print(sample_ip.shape)
print(multi(sample_ip).shape)  # (2, 2, 32)

The general Transformer attention block can be drawn as follows, where the first two linear layers produce the query and key and are responsible for building the attention weight map, which then weights the value via matrix multiplication; a minimal sketch of this flow follows the image credit below.

Image credit: https://www.youtube.com/watch?v=mMa2PmYJlCo
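As a minimal single-head sketch of that flow (the layer sizes and input shape below are illustrative, not taken from the question): the first two Dense layers produce Q and K, their scaled product gives the attention weight map, and that map weights V by matrix multiplication.

import tensorflow as tf

# Illustrative sizes only
seq_len, d_model, d_k = 4, 16, 8
x = tf.random.normal((1, seq_len, d_model))   # one batch of token embeddings

wq = tf.keras.layers.Dense(d_k)   # first linear layer  -> query
wk = tf.keras.layers.Dense(d_k)   # second linear layer -> key
wv = tf.keras.layers.Dense(d_k)   # third linear layer  -> value

q, k, v = wq(x), wk(x), wv(x)
scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(d_k, tf.float32))
weights = tf.nn.softmax(scores, axis=-1)      # attention weight map, shape (1, 4, 4)
out = tf.matmul(weights, v)                   # weighted value, shape (1, 4, 8)
print(weights.shape, out.shape)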

I know you are trying to minimize the original TF tutorial code (https://www.tensorflow.org/tutorials/text/transformer), but I think you should first add a reference to it in your original question. In the original implementation they also return the attention weights (the softmaxed scores) along with the weighted feature map, and I don't think you should skip that.
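A minimal sketch of that change, keeping the structure of the question's scaled_dot_product (and assuming tensorflow is imported as tf as above):

def scaled_dot_product(q, k, v):
    # Q . K^T -> raw attention scores
    qkt = tf.matmul(q, k, transpose_b=True)
    # scale by sqrt(d_k), the key dimension
    dk = tf.math.sqrt(tf.cast(tf.shape(k)[-1], tf.float32))
    attention_weights = tf.nn.softmax(qkt / dk, axis=-1)
    # return the weighted feature map together with the attention weights
    return tf.matmul(attention_weights, v), attention_weights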


The original code (https://www.tensorflow.org/tutorials/text/transformer) you are following is more general and more efficiently optimized.

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
        # scale matmul_qk
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        # add the mask to the scaled tensor.
        if mask is not None: scaled_attention_logits += (mask * -1e9)
        # softmax is normalized on the last axis (seq_len_k) so that the scores
        # add up to 1.
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
        return output, attention_weights

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        batch_size = tf.shape(q)[0]
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)

        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = tf.reshape(scaled_attention,  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights

FYI, in TF 2.4 the tf.keras.layers.MultiHeadAttention layer (https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention) was officially added.

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
input_tensor = tf.keras.Input(shape=[2, 2, 32]); print(input_tensor.shape)
print(layer(input_tensor, input_tensor).shape)

You can test both of them as follows:

# custom layer MHA
multi = MultiHeadAttention(d_model=512, num_heads=2)
y = tf.random.uniform((1, 60, 512))  
out, attn = multi(y, k=y, q=y, mask=None)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))

# built-in layer 
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
y = tf.random.uniform((1, 60, 512))  
out, attn = layer(y, y, return_attention_scores=True)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))
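In both cases the output shape is (1, 60, 512), matching the input, and the attention map has shape (1, 2, 60, 60), i.e. one (seq_len_q, seq_len_k) weight map per head.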