attention(注意力机制)原理和pytorch demo

2023-11-19

目录

说明

RNN的局限性

注意力机制原理

注意力机制实现

第一步:编码

第二步:第0次打分并解码

第三步:第1次打分并解码

Demo链接和结果分析

总结&改进


说明

demo源自吴恩达老师的课程,从tensorflow修改为pytorch,略有不同。

RNN的局限性

原始数据是一个字符串:friday august 17 2001,长度是21(包含空格),为了简便这里把每一个字符用一个onehot向量表示。于是数据转化为21个onehot向量。依次输入到一个RNN网络(可以是普通RNN、也可以是LSTM和GRU),最终得到一个向量(即RNN网络中的隐状态)。如果此时用这个向量作为整个字符串的编码信息直接去解码,很可能会丢失一些信息,尤其是输入更长的字符串时,更容易丢失信息。并且很难抽取距离较远的两个特征之间的关系。

注意力机制原理

我们的目标是把这个字符串翻译成2001-08-17。想象一下如果是人来进行这个翻译,那么我们会做出如下映射关系,箭头即表示人的注意力机制。神经网络的注意力机制就是在模仿人类。

注意力机制实现

第一步:编码

由于输入序列是不定长,为方便计算,将全部输入都补充到长为30,补充方式为末尾加特定字符,记为<pad>。即friday august 17 2001<pad><pad><pad><pad><pad><pad><pad><pad><pad>

然后把对应的30个onehot向量(对于其他任务,可以是不同的特征向量),依次输入到encoder网络(这里使用双向LSTM)中,每次计算得到的隐状态向量全都保存下来,一共是30个(这里LSTM的隐状态向量长度设为64,由于是双向LSTM,长度一共是128),作为初始特征,记作Feature_30x128,这里30表示时间序列长度。

第二步:第0次打分并解码

此时解码部分RNN网络的隐状态向量H初始为全零(这里向量长度是64),复制30份,然后和Feature_30x128拼起来得到Feature_30x192。然后输入到一个全连接网络,输出是30*1维矩阵,即长为30的向量,最后经过softmax,得到30个打分(softmax的目的是让30个打分之和为1)。

此时有30个长为128的初始特征,即Feature_30x128;以及30个打分,相乘后加起来,得到一个128维的打分后特征,此操作举例如下(为简便,例子中的特征维度不是30*128,是3*4,则分数有3个)。

{\color{Red} {\color{Red} }Feature\_30x128}=\begin{bmatrix} 0.1 & 1.1 &0.7 \\ 0.2 & 0.5 &1.4 \\ 0.4 & 0.3 &0.5 \\ 0.3 & 0.6 &0.2 \end{bmatrix},score=\begin{bmatrix} 0.1 & 0.7 & 0.2 \end{bmatrix}

相乘后如下。

{\color{Red} Feature\_30x128}=\begin{bmatrix} 0.01 & 0.77 &0.14 \\ 0.02 & 0.35 &0.28 \\ 0.04 & 0.21 &0.1 \\ 0.03 & 0.42 &0.04 \end{bmatrix}

然后沿着时间维度相加,得到

{\color{Blue} Feature\_128}=\begin{bmatrix} 0.92 \\ 0.65 \\ 0.35 \\ 0.49 \end{bmatrix}

Feature_128输入到解码部分RNN网络,只向前传播一次,得到新的输出隐状态H,然后在经过一层全连接,进行分类(即输出哪个字符)。

第三步:第1次打分并解码

往后的解码和第二步都一样,只不过H在不断变化,用以拼在Feature_30x128上,指导如何打分。

Demo链接和结果分析

代码链接:https://github.com/zcsxll/date_trans_with_attention

我们实际上一共解码10次(因为2001-08-17这种输出格式长度固定为10),每次都会的到一个长为30的打分,即一张10*30的热图,如下图。一共10行,每一行长是30,是解码对应字符时的打分结果。

从图中可知注意力机制的效果十分明显,当解码月份08时,august部分的打分较大。至于年份部分的打分并非一一对应,是因为训练数据集中,一旦出现零几年,就只有二零零几年。

总结&改进

编码和解码部分都采用了LSTM,其中包含C和H两个隐变量,都可以拼到Feature_30x128上进行打分计算。

输出是分类任务,但实通过实验,对于本实例,训练时采用MSE损失比采用交叉熵收敛得更快。

输入到网络中的特征不应该是单个字符,而应该是单词,例如august,应该作为一个特征向量进行操作,而不是6个。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

attention(注意力机制)原理和pytorch demo 的相关文章

随机推荐

  • 【嵌入式基础】串口中断通信VS串口DMA通信

    目录 目录 前言 一 串口通信 1 通信方式 2 通信速率 3 串口通信的三种工作方式 二 串口中断通信 1 串口中断特点 2 CubeMX配置初始化串口中断相关外设 3 串口中断程序分析 4 实验效果 三 串口DMA通信 1 关于DMA
  • 数字后端知识点扫盲——后端设计流程及使用工具

    1 DFT Design For Test 可测性设计 芯片每一步往往都自带测试电路 DFT的目的就是在设计的时候就考虑将来的测试 DFT的常见方法是 在设计中插入scan chain 将非扫描单元 如寄存器 变为扫描单元 DFT工具是sy
  • VS包含目录、库目录、附加依赖项、环境变量详解

    首先 提出一个问题 我们编译一个程序 都需要哪些文件 1 头文件 2 静态库lib 3 动态库dll 针对这三个文件 我们便可以设置工程的相关属性 1 头文件 我们要用到一个头文件 需要知道这个头文件的名字 然后用 include将它包含进
  • Java高级-包装类、BigDecimal和BigInteger

    基本数据类型和包装类 基本数据类型和包装类如下所示 基本类型 包装类 int java lang Integer 父类为java lang Number long java lang Long 父类为java lang Number dou
  • Maven中:可以被子模块继承的元素

    即使是长期从事 Maven 工作的开发人员也不能完全掌握聚合 多模块 和 Parent 继承的关系 在使用多模块时 子模块总要指定聚合的 pom 为
  • Linux中如何修改文件或目录的权限?

    在Linux系统中 文件权限是非常重要的一个概念 它能够决定谁可以访问文件 以及可以执行哪些操作 正确地设置文件权限可以确保系统的安全性和稳定性 那么如何设置文件权限呢 以下是详细的内容 在 Linux 系统中 可以使用 chmod 命令来
  • unity中的一些快捷键(齐)

    重命名的快捷键是F2 ALT 鼠标左键点击Hierarchy对象可以展开和收起对象的所有子物体 SHIFT 空格 可以对当前窗口进行放大缩小
  • 部署docker

    1 移除之前安装过的Docker sudo yum y remove docker docker client docker client latest docker common docker latest docker latest l
  • maven怎么引入jdom_如何在Maven项目中引入自己的jar包

    1 一般情况下jar包都可以使用pom xml来配置管理 但也有一些时候 我们项目中使用了一个内部jar文件 但是这个文件我们又没有开放到maven库中 我们会将文件放到我们项目中 以下以java工程为例随便放了个地方 2 jar包的引入和
  • RK3308 Ubuntu16.04移植

    一 概述 本章将介绍Ubuntu在RK平台上的移植以及AP配网 常用的fs为buildroot编译出来的linux文件系统 而本次则是ubuntu文件系统 系统启动后需要手动对WIFI驱动进行加载并配网 二 配置Kernel 为了支持ubu
  • python requests 爬虫--爬取HTML源码不显示正文已解决

    爬虫第一步 获取整个网页的HTML信息 源代码如下 coding UTF 8 import requests if name main target https www biqukan com 1 1094 5403177 html req
  • Wireshark过滤规则及使用方法

    前言 我看到的这篇文章是转载的 但我也不知道他是从哪转载的 o 转自 Wireshark 基本语法 基本使用方法 及包过滤规则 1 过滤IP 如来源IP或者目标IP等于某个IP 例子 ip src eq 192 168 1 107 or i
  • Java基础:常用类Compare

    Compare类 Comparable接口 自然排序 1 像String 包装类等实现了Comparable接口 重写了compareTo 方法 2 String 包装类重写了compareTo 方法后 进行了从小到大的排列 Test pu
  • Linux内核内存管理算法Buddy和Slab

    文章目录 Buddy分配器 CMA Slab分配器 总结 Buddy分配器 假设这是一段连续的页框 阴影部分表示已经被使用的页框 现在需要申请一个连续的5个页框 这个时候 在这段内存上不能找到连续的5个空闲的页框 就会去另一段内存上去寻找5
  • AtCoder Beginner Contest 169 B Multiplication 2 long long竟然不够用

    AtCoder Beginner Contest 169 比赛人数11374 比赛开始后15分钟看到A题 在比赛开始后第20分钟看到所有题 AtCoder Beginner Contest 169 B Multiplication 2 lo
  • OpenGL ES 2.0升级到3.0配置win32环境以及编译所遇bug

    安装win32平台的OpenGL ES 3 0模拟器 一 安装3 0模拟器 一般用32位的 https developer arm com products software development tools graphics devel
  • ctfshow-网络迷踪-初学再练( 一座雕像判断军事基地名称)

    ctf show 网络迷踪第4关 题目中只有一座雕像 需要根据雕像提交军事基地的名称 推荐使用谷歌识图 溯源到一篇博客 答案就在文章标题中 给了一座雕像 看样子不像是国内的风格 扔谷歌识图找找线索 访问谷歌识图 根据图片搜索 https w
  • kubernetes常见异常处理

    一 kubernetes常见Pod异常状态的处理 一 一般排查方式 无论 Pod 处于什么异常状态 都可以执行以下命令来查看 Pod 的状态 kubectl get pod
  • 拉格朗日乘数法

    拉格朗日乘数法
  • attention(注意力机制)原理和pytorch demo

    目录 说明 RNN的局限性 注意力机制原理 注意力机制实现 第一步 编码 第二步 第0次打分并解码 第三步 第1次打分并解码 Demo链接和结果分析 总结 改进 说明 demo源自吴恩达老师的课程 从tensorflow修改为pytorch