【NLP实践】使用Pytorch进行文本分类——BILSTM+ATTENTION

2023-11-04

网络结构

BILSTM+ATTENTION 网络结构

代码实现

class TextBILSTM(nn.Module):
    
    def __init__(self,
                 config:TRNNConfig,
                 char_size = 5000,
                 pinyin_size = 5000):
        super(TextBILSTM, self).__init__()
        self.num_classes = config.num_classes
        self.learning_rate = config.learning_rate
        self.keep_dropout = config.keep_dropout
        self.char_embedding_size = config.char_embedding_size
        self.pinyin_embedding_size = config.pinyin_embedding_size
        self.l2_reg_lambda = config.l2_reg_lambda
        self.hidden_dims = config.hidden_dims
        self.char_size = char_size
        self.pinyin_size = pinyin_size
        self.rnn_layers = config.rnn_layers

        self.build_model()


    def build_model(self):
        # 初始化字向量
        self.char_embeddings = nn.Embedding(self.char_size, self.char_embedding_size)
        # 字向量参与更新
        self.char_embeddings.weight.requires_grad = True
        # 初始化拼音向量
        self.pinyin_embeddings = nn.Embedding(self.pinyin_size, self.pinyin_embedding_size)
        self.pinyin_embeddings.weight.requires_grad = True
        # attention layer
        self.attention_layer = nn.Sequential(
            nn.Linear(self.hidden_dims, self.hidden_dims),
            nn.ReLU(inplace=True)
        )
        # self.attention_weights = self.attention_weights.view(self.hidden_dims, 1)

        # 双层lstm
        self.lstm_net = nn.LSTM(self.char_embedding_size, self.hidden_dims,
                                num_layers=self.rnn_layers, dropout=self.keep_dropout,
                                bidirectional=True)
        # FC层
        # self.fc_out = nn.Linear(self.hidden_dims, self.num_classes)
        self.fc_out = nn.Sequential(
            nn.Dropout(self.keep_dropout),
            nn.Linear(self.hidden_dims, self.hidden_dims),
            nn.ReLU(inplace=True),
            nn.Dropout(self.keep_dropout),
            nn.Linear(self.hidden_dims, self.num_classes)
        )

    def attention_net_with_w(self, lstm_out, lstm_hidden):
        '''

        :param lstm_out:    [batch_size, len_seq, n_hidden * 2]
        :param lstm_hidden: [batch_size, num_layers * num_directions, n_hidden]
        :return: [batch_size, n_hidden]
        '''
        lstm_tmp_out = torch.chunk(lstm_out, 2, -1)
        # h [batch_size, time_step, hidden_dims]
        h = lstm_tmp_out[0] + lstm_tmp_out[1]
        # [batch_size, num_layers * num_directions, n_hidden]
        lstm_hidden = torch.sum(lstm_hidden, dim=1)
        # [batch_size, 1, n_hidden]
        lstm_hidden = lstm_hidden.unsqueeze(1)
        # atten_w [batch_size, 1, hidden_dims]
        atten_w = self.attention_layer(lstm_hidden)
        # m [batch_size, time_step, hidden_dims]
        m = nn.Tanh()(h)
        # atten_context [batch_size, 1, time_step]
        atten_context = torch.bmm(atten_w, m.transpose(1, 2))
        # softmax_w [batch_size, 1, time_step]
        softmax_w = F.softmax(atten_context, dim=-1)
        # context [batch_size, 1, hidden_dims]
        context = torch.bmm(softmax_w, h)
        result = context.squeeze(1)
        return result

    def forward(self, char_id, pinyin_id):
        # char_id = torch.from_numpy(np.array(input[0])).long()
        # pinyin_id = torch.from_numpy(np.array(input[1])).long()

        sen_char_input = self.char_embeddings(char_id)
        sen_pinyin_input = self.pinyin_embeddings(pinyin_id)

        sen_input = torch.cat((sen_char_input, sen_pinyin_input), dim=1)
        # input : [len_seq, batch_size, embedding_dim]
        sen_input = sen_input.permute(1, 0, 2)
        output, (final_hidden_state, final_cell_state) = self.lstm_net(sen_input)
        # output : [batch_size, len_seq, n_hidden * 2]
        output = output.permute(1, 0, 2)
        # final_hidden_state : [batch_size, num_layers * num_directions, n_hidden]
        final_hidden_state = final_hidden_state.permute(1, 0, 2)
        # final_hidden_state = torch.mean(final_hidden_state, dim=0, keepdim=True)
        # atten_out = self.attention_net(output, final_hidden_state)
        atten_out = self.attention_net_with_w(output, final_hidden_state)
        return self.fc_out(atten_out)
        

Attention计算

  1. 将BILSTM网络输出的结果(shape:[batch_size, time_step, hidden_dims * num_directions(=2)])拆成两个大小为[batch_size, time_step, hidden_dims]的Tensor;
  2. 将第一步拆出的两个Tensor进行相加运算得到h(shape:[batch_size, time_step, hidden_dims]);
  3. 将BILSTM网络最后一个隐状态(shape:[batch_size, num_layers * num_directions, hidden_dims])在第二维度进行求和,得到新的lstm_hidden(shape:[batch_size, hidden_dims]);
  4. lstm_hidden的维度从[batch_size, n_hidden]扩展到[batch_size, 1, hidden_dims];
  5. 使用slef.atten_layer(h)获得用于后续计算权重的向量atten_w(shape:[batch_size, 1, hidden_dims]);
  6. h进行tanh激活,得到m(shape:[batch_size, time_step, hidden_dims]);
  7. 使用torch.bmm(atten_w, m.transpose(1, 2)) 得到atten_context(shape:[batch_size, 1, time_step]);
  8. atten_context使用F.softmax(atten_context, dim=-1)进行归一化,得到基于上下文权重的softmax_w(shape:[batch_size, 1, time_step]);
  9. 使用torch.bmm(softmax_w, h)得到基于权重的BILSTM输出context(shape:[batch_size, 1, hidden_dims]);
  10. context的第二维度消掉,得到result(shape:[batch_size, hidden_dims]) ;
  11. 返回result

模型效果

  • 1层BILSTM在训练集准确率:99.8%,测试集准确率:96.5%
  • 2层BILSTM在训练集准确率:99.9%,测试集准确率:97.3%

调参

  • dropout的值要在 0.1 以下(经验之谈,笔者在实践中发现,dropout取0.1时比dropout取0.3时在测试集准确率能提高0.5%)。

相关文章

  • 使用TextCNN模型进行文本分类(链接);
  • 使用Transformer模型进行文本分类(链接
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【NLP实践】使用Pytorch进行文本分类——BILSTM+ATTENTION 的相关文章

随机推荐

  • 【STM32】stm32是什么

    作者 一只大喵咪1201 专栏 STM32学习 格言 你只管努力 剩下的交给时间 STM32的简单介绍 介绍 特点 认识STM32 总结 介绍 STM32是一款单片机 它由意法半导体公司制造 ST是意法半导体的简称 M是指微控制器 也就是单
  • 「react进阶」年终送给react开发者的八条优化建议(篇幅较长,占用20-30分钟)

    笔者是一个 react 重度爱好者 在工作之余 也看了不少的 react 文章 写了很多 react 项目 接下来笔者讨论一下 React 性能优化的主要方向和一些工作中的小技巧 送人玫瑰 手留余香 阅读的朋友可以给笔者点赞 关注一波 公众
  • API hook 原理与Windows hook 应用

    API hook 原理与Windows hook 应用 分类 系统程序 2012 04 14 12 20 3679人阅读 评论 3 收藏 举报 hook api windows attributes descriptor winapi 目录
  • seaborn.heatmap部分参数解释

    今天也是自己用seaborn的heatmap花了一个混淆矩阵 sns heatmap ConfusoinMatrix annot True ax ax cmap Greens 发现他这个对数据倾斜的数据很不友好啊 如果有一个类别的数据超级多
  • SystemC自带example的simple_perf研习

    simple perf SystemC的性能建模示例 也是SystemC中系统级建模的一个入门简介 SystemC自带example的系列 SystemC自带example的pipe研习 SystemC自带example的pkt switc
  • fabric1.0 错误分析总结

    个人在学习 fabric1 0 项目中遇到的 一些 错误和原因总结 如发现错误即时指出 1 ERROR could not find an available non overlapping IPv4 address pool among
  • 用Rancher RKE快速部署高可用k8s集群

    用Rancher部署高可用k8s集群 用Rancher RKE部署高可用k8s集群 1 主机配置 1 1 新建虚拟主机 1 2 主机初始化配置 安装一些必要的安装包 安全设置 ntp时钟同步 内核参数调优 hostname修改 关闭swap
  • 诠释韧性增长,知乎Q3财报里的社区优势和商业化价值

    当内容平台开始做生意 往往意味着它要扮演一个大包大揽的角色 从内容的可持续性到最终变现 设计一套完整的生态系统是必需的 但并非所有平台都对此感到棘手 或者说在某些平台 生态已经不是困难 而是优势和特色 知乎就是从好平台走向好公司的典型例子
  • scrapy中使用css选择器罗列下一级的所有标签

    使用 css dl gt 即为罗列dl标签的下一级所有标签 例子 dt dl a dl dl b dl dl h1 c h1 dl dt 使用 data dt response css dl dt id all child elements
  • Python-Tkinter 图形化界面设计

    摘抄来自Python Tkinter 图形化界面设计 还是自己去看一下比较好 我只是摘抄我用的上的 一 最基本框架 from tkinter import root Tk root title 我的第一个Python窗体 root geom
  • P2524 Uim的情人节礼物·其之弐【康托展开模板题】

    题目链接 我在这里加了树状数组来优化康托展开 但是这道题的数据其实很小 不需要加也是可以的 include
  • 27 类深度学习主要神经网络

    1 感知器 Perceptron P 感知器模型也称为单层神经网络 这个神经网络只包含两层 输入层 输出层 这种类型的神经网络没有隐藏层 它接受输入并计算每个节点的加权 然后 它使用激活函数 大多数是Sigmoid函数 进行分类 应用 分类
  • 49 题目 1431: [蓝桥杯][2014年第五届真题]分糖果

    题目 1431 蓝桥杯 2014年第五届真题 分糖果 时间限制 1Sec 内存限制 128MB 提交 5807 解决 2969 题目描述 问题描述 有n个小朋友围坐成一圈 老师给每个小朋友随机发偶数个糖果 然后进行下面的游戏 每个小朋友都把
  • Python中Requests模块的异常值处理

    在我们用Python的requests模块进行爬虫时 一个简单高效的模块就是requests模块 利用get 或者post 函数 发送请求 但是在真正的实际使用过程中 我们可能会遇到网络的各种变化 可能会导致请求过程发生各种未知的错误导致程
  • Vue中的路由以及默认路由跳转

    文章目录 官方网址 Vue路由配置 安装 引入并使用 配置路由 官方网址 https router vuejs org Vue路由配置 安装 npm install vue router save 或者 cnpm install vue r
  • SpringBoot集成XxlJob分布式任务调度中心(超详细之手把手教学)

    一 前言 废话就不多说了 介绍Xxl Job的网上已经有很多 本文就不多加复制粘贴了 直接步入第二步 PS 本文包括Xxl Job分布式定时任务调度中心的搭建 以及SpringBoot集成XxlJob的全过程 如果不想了解搭建的小伙伴可以直
  • 判断加密方式

    如何判断密文的加密方式 1 如果密文是十进制 字符范围是 0 9 可以猜测是ASCII编码 2 如果密文由 a z A Z 和 构成 特别是末尾有 那么判断可能是Base64编码 Base64在线解码网址 BASE64加密解密 3 如果密文
  • Docker 部署 RocketMQ

    文章目录 安装nameserver 拉取镜像 运行容器 出现问题卸载 安装broker 创建配置文件 运行容器 出现问题卸载 安装控制台 拉取镜像 运行容器 出现问题卸载 安装nameserver 拉取镜像 docker pull rock
  • 时序预测

    时序预测 MATLAB实现ARIMA时间序列预测 armax函数 本程序基于MATLAB的armax函数实现arima时间序列预测 实现了模型趋势分析 序列差分 序列平稳化 AIC准则模型参数识别与定阶 预测结果与误差分析过程 逻辑清晰 数
  • 【NLP实践】使用Pytorch进行文本分类——BILSTM+ATTENTION

    目录 网络结构 代码实现 Attention计算 模型效果 调参 相关文章 网络结构 代码实现 class TextBILSTM nn Module def init self config TRNNConfig char size 500