Multihead Attention - 多头注意力

2023-11-20

多头注意力

在实践中,当给定 相同的查询、键和值的集合 时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖关系)。因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces) 可能是有益的。

为此,与其只使用单独一个注意力汇聚,我们可以用独立学习得到的 h h h 组不同的线性投影(linear projections) 来变换查询、键和值。然后,这 h h h 组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这 h h h 个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为多头注意力(multihead attention)。对于 h h h 个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)

本质地讲,自注意力机制是:通过某种运算来直接计算得到句子在编码过程中每个位置上的注意力权重;然后再以权重和的形式来计算得到整个句子的隐含向量表示。

自注意力机制的缺陷是:模型在对当前位置的信息进行编码时,会过度的将注意力集中于自身的位置, 因此作者提出了通过多头注意力机制来解决这一问题。

下图展示了使用全连接层来实现可学习的线性变换的多头注意力。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-R7BJtkT1-1667357320669)(attachment:QQ%E6%88%AA%E5%9B%BE20221031074721.png)]

模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。给定查询 q ∈ R d q \mathbf{q} \in \mathbb{R}^{d_q} qRdq、键 k ∈ R d k \mathbf{k} \in \mathbb{R}^{d_k} kRdk和值 v ∈ R d v \mathbf{v} \in \mathbb{R}^{d_v} vRdv,每个注意力头 h i \mathbf{h}_i hi i = 1 , … , h i = 1, \ldots, h i=1,,h)的计算方法为:

h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p v , \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v}, hi=f(Wi(q)q,Wi(k)k,Wi(v)v)Rpv,

其中,可学习的参数包括 W i ( q ) ∈ R p q × d q \mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q} Wi(q)Rpq×dq W i ( k ) ∈ R p k × d k \mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k} Wi(k)Rpk×dk W i ( v ) ∈ R p v × d v \mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v} Wi(v)Rpv×dv,以及代表注意力汇聚的函数 f f f
f f f 可以是之前学习的加性注意力缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着 h h h 个头连结后的结果,因此其可学习参数是 W o ∈ R p o × h p v \mathbf W_o\in\mathbb R^{p_o\times h p_v} WoRpo×hpv

W o [ h 1 ⋮ h h ] ∈ R p o . \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}. Wo h1hh Rpo.

基于这种设计,每个头都可能会关注输入的不同部分,可以表示比简单加权平均值更复杂的函数。

import math
import torch
from torch import nn
from d2l import torch as d2l

实现

在实现过程中,我们选择缩放点积注意力作为每一个注意力头。为了避免计算代价和参数代价的大幅增长,我们设定 p q = p k = p v = p o / h p_q = p_k = p_v = p_o / h pq=pk=pv=po/h。值得注意的是,如果我们将查询、键和值的线性变换的输出数量设置为 p q h = p k h = p v h = p o p_q h = p_k h = p_v h = p_o pqh=pkh=pvh=po,则可以并行计算 h h h 个头。在下面的实现中, p o p_o po是通过参数 num_hiddens 指定的。

class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        
    def forward(self, queries, keys, values, valid_lens):
        # queries, keys, values的形状:
        # (batch_size,查询或“键-值”对的个数,num_hiddens)
        # valid_len 的形状:
        # (batch_size,)或(batch_size,查询的个数)
        # 经过变换后,输出的queries,keys,values的形状:
        # (batch_size*num_heads,查询或“键-值”个数,num_hiddens/num_head)
        
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        
        if valid_lens is not None:
            # 在轴0,将第一项(标量或矢量) 复制 num_heads次,
            # 然后如此复制第二项,然后诸如此类
            valid_lens = torch.repeat_interleave(valid_lens,
                                                repeats=self.num_heads,
                                                dim=0)
        
        
        # output的形状:(batch_size*num_heads, 查询个数,num_hiddens/num_head)
        output = self.attention(queries, keys, values, valid_lens)
        # output_concat的形状:(batch_size, 查询个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
            

为了能够使多个头并行计算,上面的 MultiHeadAttention 类将使用下面定义的两个转置函数。具体来说,transpose_output 函数反转了 transpose_qkv 函数的操作。

def transpose_qkv(X, num_heads):
    """为了多头注意力的并行计算而变换形状"""
    # 输入X的形状(batch_size, 查询或”键-值“对的个数,num_hiddens)
    # 输出X的形状(batch_size,查询或”键-值“对的个数,
    # num_heads,num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    
    # 输出X的形状(batch_size,
    # num_heads,查询或”键-值“对的个数,num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    
    # 输出X的形状(batch_size*num_heads,
    # 查询或”键-值“对的个数,num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    # 输入X的形状(batch_size*num_heads,
    # 查询或”键-值“对的个数,num_hiddens/num_heads)
    
    # 输出X的形状(batch_size,
    # num_heads,查询或”键-值“对的个数,num_hiddens/num_heads)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    
    # 输出X的形状(batch_size,查询或”键-值“对的个数,
    # num_heads,num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    
    # 输出X的形状(batch_size,查询或”键-值“对的个数,num_hiddens)
    return X.reshape(X.shape[0], X.shape[1], -1)

下面我们使用键和值相同的小例子来测试我们编写的 MultiHeadAttention 类。多头注意力输出的形状是 (batch_size,num_queries, num_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                              num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])

小结

1、多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。

2、基于适当的张量操作,可以实现多头注意力的并行计算。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Multihead Attention - 多头注意力 的相关文章

  • Django REST序列化器:创建对象而不保存

    我已经开始使用 Django REST 框架 我想做的是使用一些 JSON 发布请求 从中创建一个 Django 模型对象 然后使用该对象而不保存它 我的 Django 模型称为 SearchRequest 我所拥有的是 api view
  • 如何收集列表、字典等中重复计算的结果(或制作修改每个元素的列表的副本)?

    There are a great many existing Q A on Stack Overflow on this general theme but they are all either poor quality typical
  • Python 多处理示例不起作用

    我正在尝试学习如何使用multiprocessing但我无法让它发挥作用 这是代码文档 http docs python org 2 library multiprocessing html from multiprocessing imp
  • 如何在Windows上模拟socket.socketpair

    标准Python函数套接字 套接字对 https docs python org 3 library socket html socket socketpair不幸的是 它在 Windows 上不可用 从 Python 3 4 1 开始 我
  • SQL Alchemy 中的 NULL 安全不等式比较?

    目前 我知道如何表达 NULL 安全的唯一方法 SQL Alchemy 中的比较 其中与 NULL 条目的比较计算结果为 True 而不是 NULL 是 or field None field value 有没有办法在 SQL Alchem
  • __del__ 真的是析构函数吗?

    我主要用 C 做事情 其中 析构函数方法实际上是为了销毁所获取的资源 最近我开始使用python 这真的很有趣而且很棒 我开始了解到它有像java一样的GC 因此 没有过分强调对象所有权 构造和销毁 据我所知 init 方法对我来说在 py
  • Python 中的二进制缓冲区

    在Python中你可以使用StringIO https docs python org library struct html用于字符数据的类似文件的缓冲区 内存映射文件 https docs python org library mmap
  • Python:字符串不会转换为浮点数[重复]

    这个问题在这里已经有答案了 我几个小时前写了这个程序 while True print What would you like me to double line raw input gt if line done break else f
  • 使用 OpenPyXL 迭代工作表和单元格,并使用包含的字符串更新单元格[重复]

    这个问题在这里已经有答案了 我想使用 OpenPyXL 来搜索工作簿 但我遇到了一些问题 希望有人可以帮助解决 以下是一些障碍 待办事项 我的工作表和单元格数量未知 我想搜索工作簿并将工作表名称放入数组中 我想循环遍历每个数组项并搜索包含特
  • Python - 按月对日期进行分组

    这是一个简单的问题 起初我认为很简单而忽略了它 一个小时过去了 我不太确定 所以 我有一个Python列表datetime对象 我想用图表来表示它们 x 值是年份和月份 y 值是此列表中本月发生的日期对象的数量 也许一个例子可以更好地证明这
  • Python - 在窗口最小化或隐藏时使用 pywinauto 控制窗口

    我正在尝试做的事情 我正在尝试使用 pywinauto 在 python 中创建一个脚本 以在后台自动安装 notepad 隐藏或最小化 notepad 只是一个示例 因为我将编辑它以与其他软件一起使用 Problem 问题是我想在安装程序
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • 如何改变Python中特定打印字母的颜色?

    我正在尝试做一个简短的测验 并且想将错误答案显示为红色 欢迎来到我的测验 您想开始吗 是的 祝你好运 法国的首都是哪里 法国 随机答案不正确的答案 我正在尝试将其显示为红色 我的代码是 print Welcome to my Quiz be
  • Python:计算字典的重复值

    我有一本字典如下 dictA unit1 test1 alpha unit1 test2 beta unit2 test1 alpha unit2 test2 gamma unit3 test1 delta unit3 test2 gamm
  • VSCode:调试配置中的 Python 路径无效

    对 Python 和 VSCode 以及 stackoverflow 非常陌生 直到最近 我已经使用了大约 3 个月 一切都很好 当尝试在调试器中运行任何基本的 Python 程序时 弹出窗口The Python path in your
  • 在python中,如何仅搜索所选子字符串之前的一个单词

    给定文本文件中的长行列表 我只想返回紧邻其前面的子字符串 例如单词狗 描述狗的单词 例如 假设有这些行包含狗 hotdog big dog is dogged dog spy with my dog brown dogs 在这种情况下 期望
  • 协方差矩阵的对角元素不是 1 pandas/numpy

    我有以下数据框 A B 0 1 5 1 2 6 2 3 7 3 4 8 我想计算协方差 a df iloc 0 values b df iloc 1 values 使用 numpy 作为 cov numpy cov a b I get ar
  • Spark.read 在 Databricks 中给出 KrbException

    我正在尝试从 databricks 笔记本连接到 SQL 数据库 以下是我的代码 jdbcDF spark read format com microsoft sqlserver jdbc spark option url jdbc sql
  • Pandas 与 Numpy 数据帧

    看这几行代码 df2 df copy df2 1 df 1 df 1 values 1 df2 ix 0 0 我们的教练说我们需要使用 values属性来访问底层的 numpy 数组 否则我们的代码将无法工作 我知道 pandas Data
  • PyAudio ErrNo 输入溢出 -9981

    我遇到了与用户相同的错误 Python 使用 Pyaudio 以 16000Hz 录制音频时出错 https stackoverflow com questions 12994981 python error audio recording

随机推荐

  • js制作简易计算器

    实现2个输入框中输入整数后 点击第三个输入框能给出2个整数的加减乘除 提示 获取元素的值设置和获取方法为 例 赋值 document getElementById id value 1 取值 var document getElementB
  • oracle 建表 提示 ora-00955:名称已由现有对象使用.

    问题 oracle 执行 Create table 设备执行库房 设备id number 18 执行库房id number 18 提示 ora 00955 名称已由现有对象使用 但是执行 drop table 设备执行库房 却提示表或视图不
  • java生成随机数组_JAVA生成随机数组10个数字并求和

    JAVA生成随机数组10个数字并求和 本文最终结果大概是这样的 使用java技术随机生成10个数 然后填充一个数组并在消息框中显示数组内容 接着对数组求和输出 将结果显示在消息框中 设计思路 可以先用Math Random 1000生成10
  • hive-05-Execution Error, return code 3 from org.apache.hadoop.hive.ql.exec.mr.MapredLocalTask

    hive命令行里执行了一句话 select from person join zhanghao on person zjhm zhanghao zjhm limit 100 就是两个表做连接查询 数据量大小一个是3千万 一个是3亿 结果报错
  • 【Transformer系列】深入浅出理解Transformer网络模型(综合篇)

    一 参考资料 The Illustrated Transformer 图解Transformer 完整版 Attention Is All You Need The Core Idea of the Transformer transfor
  • 【含源码】两种不同风格的圣诞树代码合集,其中还有可以改名字的圣诞树代码

    提示 文章写完后 目录可以自动生成 如何生成可参考右边的帮助文档 文章目录 前言 前言 一年一度的圣诞节马上就要到了 看到好多程序员小伙伴已经开始炫耀自己制作的圣诞树了 今天就跟大家分享2种不同风格的圣诞树 附上完整代码 拿来即用可以按照自
  • Linux常用命令记录

    文章目录 1 软件安装 安装软件 来自源服务器 安装 deb软件 来自本地 deb文件 修复依赖关系 卸载软件 2 文件 文件夹操作 删除文件夹 移动文件 文件重命名 3 程序查看 处理 进程查看 查看端口占用情况 强制终止程序 4 解压文
  • 肖sir__mysql之单表__004

    mysql之单表 一 建表语句 1 show databases 查看所有的数据库 2 create databaes 数据库名 创建数据库 3 use 数据库名 指定使用数据库 4 show tables 5 创建表 格式 create
  • linux计算字符串个数,Linux 统计某个字符串个数的方法

    在 Linux 系统下 有时候 我们可能要对一个日志文件进行分析 比如 分析日志文件中某个单词或者某个特殊字符串出现了多少次 对于匹配统计 一般用到正则方法 下面总结了几个统计字符串个数的方法 方法一 使用 grep 命令 grep o 字
  • Python自动化测试 软件测试最全教程(附笔记),看完就可就业

    最近看到很多粉丝在后台私信我 叫我做一期Python自动化测试的教程 其实关于这个问题 我也早就在着手准备了 我录制了一整套完整的Python自动化测试的教程 都上传在B站上面 大家有兴趣的可以去看一下 Python自动化测试 手把手教你做
  • springboot不香吗?为什么还要使用springcloud

    1 为什么要使用springcloud 如果我们的服务需要调用另外的一个服务 当然可以通过url 加接口直接调用 但是如果url变动后 我们也要跟着修改 还有可能服务宕机我们也不知道 而且现在只有一个url不具备高可用性 就算有多个url
  • Hudi Log 文件格式与读写流程

    Hudi Log 文件格式与读写流程 背景 对 Hudi 有一定了解的读者应该知道 Hudi 有 COW 和 MOR 两种表类型 其中的 MOR 表会通过日志文件记录文件 写入一个 MOR 表后产生的文件可以观察到 一个 MOR 表数据存储
  • 【LeetCode与《代码随想录》】字符串篇:做题笔记与总结-JavaScript版

    文章目录 代码随想录 主要题目 344 反转字符串 541 反转字符串 II 剑指 Offer 05 替换空格 151 反转字符串中的单词 剑指 Offer 58 II 左旋转字符串 28 找出字符串中第一个匹配项的下标 KMP 还没写 4
  • 我最喜欢的10个顶级数据科学资源,kaggle、TDS、arXiv......

    当我声明数据科学正在成为最受欢迎的工作领域之一时 我想你不会与我争辩 特别是考虑到 哈佛商业评论 将 数据科学家 评为21世纪最性感的工作 在这个领域 我们已经走过了很长的路 从数据科学和机器学习等术语还不为人所知 到一切都聚集在统计学的保
  • systemd[1]: Failed to load SELinux policy. freezing.

    今天早上发现centos7无法启动了 界面提示systemd 1 Failed to load SELinux policy freezing 查到一篇资料说是selinux设置出问题了 他将 etc selinux config文件中的s
  • MATLAB进行模式识别的实验

    一 实验一习题 我猜测是根据最大似然估计法先求出那两个参数的值 然后代入 得到的是只关于x的函数 然后把文本里的1000个数据导入 画图 首先 我先把txt的数据读取到矩阵里面 方便后续处理 用到的函数 1 这里有一个比较详细的fopen的
  • docker部署war包、将容器打包成镜像、镜像导出到本地、镜像推送到dockerhub

    前言 最近公司使用帆软 finereport 报表工具制作数据报表 并且需要将制作好的报表打包成war包通过docker部署 并且将部署好的项目制作成docker镜像 发给客户 下面将部署过程中踩的坑总结一下 想要了解帆软可以点击官方链接查
  • 图片上传服务器系统说明

    图片服务器测试用例 图片上传服务器系统说明 数据库设计 drop database if exists drawing bed create database drawing bed character set utf8mb4 use dr
  • 东风小康为什么是dfsk_自吸这么“香”,为什么现在新车都是涡轮增压

    知乎视频 www zhihu com 开车不带 T 干啥都没劲 车子用了涡轮增压能够显著提升动力 能把一台 能用 的车变成 好用 的车 并且国内的排放法规也越来越严格 使用涡轮增压的同时 也具备了一些节能减排的效果 所以说 自然吸气的车越来
  • Multihead Attention - 多头注意力

    文章目录 多头注意力 模型 实现 小结 多头注意力 在实践中 当给定 相同的查询 键和值的集合 时 我们希望模型可以基于相同的注意力机制学习到不同的行为 然后将不同的行为作为知识组合起来 捕获序列内各种范围的依赖关系 例如 短距离依赖和长距