目录
说明
RNN的局限性
注意力机制原理
注意力机制实现
第一步:编码
第二步:第0次打分并解码
第三步:第1次打分并解码
Demo链接和结果分析
总结&改进
说明
demo源自吴恩达老师的课程,从tensorflow修改为pytorch,略有不同。
RNN的局限性
原始数据是一个字符串:friday august 17 2001,长度是21(包含空格),为了简便这里把每一个字符用一个onehot向量表示。于是数据转化为21个onehot向量。依次输入到一个RNN网络(可以是普通RNN、也可以是LSTM和GRU),最终得到一个向量(即RNN网络中的隐状态)。如果此时用这个向量作为整个字符串的编码信息直接去解码,很可能会丢失一些信息,尤其是输入更长的字符串时,更容易丢失信息。并且很难抽取距离较远的两个特征之间的关系。
注意力机制原理
我们的目标是把这个字符串翻译成2001-08-17。想象一下如果是人来进行这个翻译,那么我们会做出如下映射关系,箭头即表示人的注意力机制。神经网络的注意力机制就是在模仿人类。
![](https://img-blog.csdnimg.cn/20200921183907961.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl8zOTIyODM4MQ==,size_16,color_FFFFFF,t_70)
注意力机制实现
第一步:编码
由于输入序列是不定长,为方便计算,将全部输入都补充到长为30,补充方式为末尾加特定字符,记为<pad>。即friday august 17 2001<pad><pad><pad><pad><pad><pad><pad><pad><pad>。
然后把对应的30个onehot向量(对于其他任务,可以是不同的特征向量),依次输入到encoder网络(这里使用双向LSTM)中,每次计算得到的隐状态向量全都保存下来,一共是30个(这里LSTM的隐状态向量长度设为64,由于是双向LSTM,长度一共是128),作为初始特征,记作Feature_30x128,这里30表示时间序列长度。
第二步:第0次打分并解码
此时解码部分RNN网络的隐状态向量H初始为全零(这里向量长度是64),复制30份,然后和Feature_30x128拼起来得到Feature_30x192。然后输入到一个全连接网络,输出是30*1维矩阵,即长为30的向量,最后经过softmax,得到30个打分(softmax的目的是让30个打分之和为1)。
此时有30个长为128的初始特征,即Feature_30x128;以及30个打分,相乘后加起来,得到一个128维的打分后特征,此操作举例如下(为简便,例子中的特征维度不是30*128,是3*4,则分数有3个)。
![{\color{Red} {\color{Red} }Feature\_30x128}=\begin{bmatrix} 0.1 & 1.1 &0.7 \\ 0.2 & 0.5 &1.4 \\ 0.4 & 0.3 &0.5 \\ 0.3 & 0.6 &0.2 \end{bmatrix},score=\begin{bmatrix} 0.1 & 0.7 & 0.2 \end{bmatrix}](https://private.codecogs.com/gif.latex?%7B%5Ccolor%7BRed%7D%20%7B%5Ccolor%7BRed%7D%20%7DFeature%5C_30x128%7D%3D%5Cbegin%7Bbmatrix%7D%200.1%20%26%201.1%20%260.7%20%5C%5C%200.2%20%26%200.5%20%261.4%20%5C%5C%200.4%20%26%200.3%20%260.5%20%5C%5C%200.3%20%26%200.6%20%260.2%20%5Cend%7Bbmatrix%7D%2Cscore%3D%5Cbegin%7Bbmatrix%7D%200.1%20%26%200.7%20%26%200.2%20%5Cend%7Bbmatrix%7D)
相乘后如下。
![{\color{Red} Feature\_30x128}=\begin{bmatrix} 0.01 & 0.77 &0.14 \\ 0.02 & 0.35 &0.28 \\ 0.04 & 0.21 &0.1 \\ 0.03 & 0.42 &0.04 \end{bmatrix}](https://private.codecogs.com/gif.latex?%7B%5Ccolor%7BRed%7D%20Feature%5C_30x128%7D%3D%5Cbegin%7Bbmatrix%7D%200.01%20%26%200.77%20%260.14%20%5C%5C%200.02%20%26%200.35%20%260.28%20%5C%5C%200.04%20%26%200.21%20%260.1%20%5C%5C%200.03%20%26%200.42%20%260.04%20%5Cend%7Bbmatrix%7D)
然后沿着时间维度相加,得到
![{\color{Blue} Feature\_128}=\begin{bmatrix} 0.92 \\ 0.65 \\ 0.35 \\ 0.49 \end{bmatrix}](https://private.codecogs.com/gif.latex?%7B%5Ccolor%7BBlue%7D%20Feature%5C_128%7D%3D%5Cbegin%7Bbmatrix%7D%200.92%20%5C%5C%200.65%20%5C%5C%200.35%20%5C%5C%200.49%20%5Cend%7Bbmatrix%7D)
将Feature_128输入到解码部分RNN网络,只向前传播一次,得到新的输出隐状态H,然后在经过一层全连接,进行分类(即输出哪个字符)。
第三步:第1次打分并解码
往后的解码和第二步都一样,只不过H在不断变化,用以拼在Feature_30x128上,指导如何打分。
Demo链接和结果分析
代码链接:https://github.com/zcsxll/date_trans_with_attention
我们实际上一共解码10次(因为2001-08-17这种输出格式长度固定为10),每次都会的到一个长为30的打分,即一张10*30的热图,如下图。一共10行,每一行长是30,是解码对应字符时的打分结果。
从图中可知注意力机制的效果十分明显,当解码月份0和8时,august部分的打分较大。至于年份部分的打分并非一一对应,是因为训练数据集中,一旦出现零几年,就只有二零零几年。
![](https://img-blog.csdnimg.cn/20200921233228835.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl8zOTIyODM4MQ==,size_16,color_FFFFFF,t_70)
总结&改进
编码和解码部分都采用了LSTM,其中包含C和H两个隐变量,都可以拼到Feature_30x128上进行打分计算。
输出是分类任务,但实通过实验,对于本实例,训练时采用MSE损失比采用交叉熵收敛得更快。
输入到网络中的特征不应该是单个字符,而应该是单词,例如august,应该作为一个特征向量进行操作,而不是6个。