使用Pytorch的LSTM文本分类

2023-05-16

Photo by Christopher Gower on Unsplash
Christopher Gower在 Unsplash上的 照片

介绍 (Intro)

Welcome to this tutorial! This tutorial will teach you how to build a bidirectional LSTM for text classification in just a few minutes. If you haven’t already checked out my previous article on BERT Text Classification, this tutorial contains similar code with that one but contains some modifications to support LSTM. This article also gives explanations on how I preprocessed the dataset used in both articles, which is the REAL and FAKE News Dataset from Kaggle.

欢迎使用本教程! 本教程将教您如何在短短几分钟内构建用于文本分类的双向LSTM 。 如果您还没有签出我以前关于BERT文本分类的文章,那么本教程将包含与该文章相似的代码,但会进行一些修改以支持LSTM。 本文还提供了有关如何预处理这两篇文章中使用的数据集的说明,这是来自Kaggle 的REAL和FAKE News数据集

First of all, what is an LSTM and why do we use it? LSTM stands for Long Short-Term Memory Network, which belongs to a larger category of neural networks called Recurrent Neural Network (RNN). Its main advantage over the vanilla RNN is that it is better capable of handling long term dependencies through its sophisticated architecture that includes three different gates: input gate, output gate, and the forget gate. The three gates operate together to decide what information to remember and what to forget in the LSTM cell over an arbitrary time.

首先,什么是LSTM?为什么要使用它? LSTM代表长期短期记忆网络 ,它属于较大的神经网络类别,称为递归神经网络(RNN) 。 与香草RNN相比,它的主要优点是它具有复杂的体系结构,能够更好地处理长期依赖性,该体系结构包括三个不同的门:输入门,输出门和遗忘门。 这三个门共同操作,以决定在任意时间内在LSTM单元中要记住哪些信息和要忘记哪些信息。

LSTM Cell
LSTM电池

Now, we have a bit more understanding of LSTM, let’s focus on how to implement it for text classification. The tutorial is divided into the following steps:

现在,我们对LSTM有了更多的了解,让我们集中于如何为文本分类实现它。 本教程分为以下步骤:

  1. Preprocess Dataset

    预处理数据集
  2. Importing Libraries

    导入库
  3. Load Dataset

    加载数据集
  4. Build Model

    建立模型
  5. Training

    训练
  6. Evaluation

    评价

Before we dive right into the tutorial, here is where you can access the code in this article:

在我们直接学习本教程之前,您可以在这里访问本文中的代码:

  • Preprocessing of Fake News Dataset

    假新闻数据集的预处理

  • LSTM Text Classification Google Colab

    LSTM文本分类Google Colab

步骤1:预处理数据集 (Step 1: Preprocess Dataset)

The raw dataset looks like the following:

原始数据集如下所示:

Dataset Overview
数据集概述

The dataset contains an arbitrary index, title, text, and the corresponding label.

数据集包含任意索引,标题,文本和相应的标签。

For preprocessing, we import Pandas and Sklearn and define some variables for path, training validation and test ratio, as well as the trim_string function which will be used to cut each sentence to the first first_n_words words. Trimming the samples in a dataset is not necessary but it enables faster training for heavier models and is normally enough to predict the outcome.

对于预处理,我们导入Pandas和Sklearn并定义一些变量,用于路径,训练验证和测试比率,以及trim_string函数,该函数将每个句子剪切为第一个first_n_words单词。 修剪数据集中的样本不是必需的,但是它可以为较重的模型提供更快的训练,并且通常足以预测结果。

Next, we convert REAL to 0 and FAKE to 1, concatenate title and text to form a new column titletext (we use both the title and text to decide the outcome), drop rows with empty text, trim each sample to the first_n_words , and split the dataset according to train_test_ratio and train_valid_ratio. We save the resulting dataframes into .csv files, getting train.csv, valid.csv, and test.csv.

接下来,我们将REAL转换为0,将FAKE转换为1,将标题文本连接起来以形成新的列标题 文本 (我们使用标题和文本来确定结果),删除带有空文本的行,将每个样本修剪为first_n_words ,然后根据train_test_ratiotrain_valid_ratio分割数据集。 我们将结果数据帧保存到.csv文件中,获得train.csvvalid.csvtest.csv

步骤2:导入库 (Step 2: Importing Libraries)

We import Pytorch for model construction, torchText for loading data, matplotlib for plotting, and sklearn for evaluation.

我们导入Pytorch用于模型构建,torchText用于加载数据,matplotlib用于绘图,而sklearn用于评估。

步骤3:载入资料集 (Step 3: Load Dataset)

First, we use torchText to create a label field for the label in our dataset and a text field for the title, text, and titletext. We then build a TabularDataset by pointing it to the path containing the train.csv, valid.csv, and test.csv dataset files. We create the train, valid, and test iterators that load the data, and finally, build the vocabulary using the train iterator (counting only the tokens with a minimum frequency of 3).

首先,我们使用torchText为数据集中的标签创建一个标签字段,并为titletexttitletext创建一个文本字段。 然后,我们通过将TabularDataset指向包含train.csvvalid.csvtest.csv数据集文件的路径来构建它。 我们创建用于加载数据的训练迭代器,有效迭代器和测试迭代器,最后,使用训练迭代器构建词汇表(仅计算最小频率为3的令牌)。

步骤4:建立模型 (Step 4: Build Model)

We construct the LSTM class that inherits from the nn.Module. Inside the LSTM, we construct an Embedding layer, followed by a bi-LSTM layer, and ending with a fully connected linear layer. In the forward function, we pass the text IDs through the embedding layer to get the embeddings, pass it through the LSTM accommodating variable-length sequences, learn from both directions, pass it through the fully connected linear layer, and finally sigmoid to get the probability of the sequences belonging to FAKE (being 1).

我们构造了从nn.Module继承的LSTM类。 在LSTM内部,我们构造了一个Embedding层,然后是bi-LSTM层,最后是一个完全连接的线性层。 在Forward函数中,我们将文本ID穿过嵌入层以获取嵌入,将其穿过LSTM容纳可变长度序列,从两个方向进行学习,将其穿过完全连接的线性层,最后再通过Sigmoid来获得属于FAKE的序列的概率(为1)。

步骤5:训练 (Step 5: Training)

Before training, we build save and load functions for checkpoints and metrics. For checkpoints, the model parameters and optimizer are saved; for metrics, the train loss, valid loss, and global steps are saved so diagrams can be easily reconstructed later.

在训练之前,我们为检查点和指标构建保存和加载功能。 对于检查点,将保存模型参数和优化器; 对于度量,可以保存火车损耗,有效损耗和全局步长,以便以后可以轻松地重建图表。

We train the LSTM with 10 epochs and save the checkpoint and metrics whenever a hyperparameter setting achieves the best (lowest) validation loss. Here is the output during training:

我们用10个时期训练LSTM,并在超参数设置达到最佳(最低)验证损失时保存检查点和度量。 这是训练期间的输出:

The whole training process was fast on Google Colab. It took less than two minutes to train!

在Google Colab上,整个培训过程非常快捷。 培训不到两分钟!

Once we finished training, we can load the metrics previously saved and output a diagram showing the training loss and validation loss throughout time.

完成训练后,我们可以加载先前保存的指标,并输出一个图表,显示整个时间的训练损失和验证损失。

步骤6:评估 (Step 6: Evaluation)

Finally for evaluation, we pick the best model previously saved and evaluate it against our test dataset. We use a default threshold of 0.5 to decide when to classify a sample as FAKE. If the model output is greater than 0.5, we classify that news as FAKE; otherwise, REAL. We output the classification report indicating the precision, recall, and F1-score for each class, as well as the overall accuracy. We also output the confusion matrix.

最后,为了进行评估,我们选择了先前保存的最佳模型,并根据测试数据集对其进行了评估。 我们使用默认阈值0.5来决定何时将样本分类为FAKE。 如果模型输出大于0.5,我们将该新闻分类为FAKE;否则,将其分类为FAKE。 否则为REAL。 我们输出分类报告,指示每个类别的精度,召回率和F1得分以及整体准确性。 我们还输出混淆矩阵。

We can see that with a one-layer bi-LSTM, we can achieve an accuracy of 77.53% on the fake news detection task.

我们可以看到,使用双层Bi-LSTM,我们可以在假新闻检测任务上达到77.53%的准确性。

结论 (Conclusion)

This tutorial gives a step-by-step explanation of implementing your own LSTM model for text classification using Pytorch. We find out that bi-LSTM achieves an acceptable accuracy for fake news detection but still has room to improve. If you want a more competitive performance, check out my previous article on BERT Text Classification!

本教程分步说明了如何使用Pytorch为文本分类实现您自己的LSTM模型。 我们发现bi-LSTM在伪造新闻检测方面达到了可接受的准确性,但仍有改进的空间。 如果您想获得更具竞争力的性能,请查看我以前关于BERT文本分类的文章!

If you want to learn more about modern NLP and deep learning, make sure to follow me for updates on upcoming articles :)

如果您想了解有关现代NLP和深度学习的更多信息,请确保关注我以获取即将发表的文章的更新:)

翻译自: https://towardsdatascience.com/lstm-text-classification-using-pytorch-2c6c657f8fc0

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用Pytorch的LSTM文本分类 的相关文章

随机推荐

  • 各种类型的Writable

    各种类型的Writable xff08 Text ByteWritable NullWritable ObjectWritable GenericWritable ArrayWritable MapWritable SortedMapWri
  • C++ strtok的用法

    size 61 large align 61 center strtok的用法 align size 函数原型 xff1a char strtok char s char delim 函数功能 xff1a 把字符串s按照字符串delim进行
  • 读《遇见未知的自己》笔记

    为什么我不快乐 xff1f 为什么我不能拥有自己想要的生活 xff1f 此刻屏幕前的你 是否想过 xff0c 自己为什么会出现这种情况呢 xff1f 张德芬在 遇见未知的自己 一书给出了解释 xff1a 我们人类所有受苦的根源就是来自不清楚
  • PX4飞控问题汇总

    接触PX4飞控代码一年多了 xff0c 代码都是模块化 开发起来比APM的方便 xff0c 使用过程中也出现过各种怪异问题 xff0c 用的硬件是V5 nano 和V5 43 xff0c 测试的代码版本是1 9和1 10 今天总结一下遇到过
  • Sumo 搭建交叉路口交通流仿真平台

    Sumo安装 注意事项 xff1a 需要工具的使用需要环境变量的设置 需要包含文件Sumo安装路径下的bin和tools Sumo配置文件 Sumo中项目的配置文件的组成如下所示 节点文件 图 1 节点及边的拓扑图 Node的属性主要有id
  • OpenWRT 各种烧录方式及量产(三)

    界面烧录 不更新uboot 电脑连接WIFI xff08 或者通过网线连接电脑与路由器 xff09 通过浏览器访问路由器管理界面 xff0c 进行升级 注意不要断电 xff01 xff01 xff01 xff08 断电只能通过tftp方式恢
  • 华为手机root

    首先手机已解锁 xff42 xff4c 此方法针对 华为手机 可使用 xff0c 其他手机没有测试 xff0c 但应该也可以 官方的twrp没有对mate xff19 进行配适 xff0c 可以使用奇兔 twrp 提取码 ax6d 如果你没
  • 阿里云ubuntu 16.04 Server配置方案 2 远程控制桌面

    通过远程控制 xff0c 更好的管理服务器 1 XRDP远程控制 为了更好的远程管理 xff0c linux一般情况都用VNC进行远程连接 xff0c 如 TightVNC X11VNC ReadVNC等 Xrdp 是开放原始码的远端桌面通
  • 自顶向下(top down)简介

    无论是在实际生活中还是在学术问题上 xff0c 复杂的问题比比皆是 xff0c 当我们对此类问题毫无头绪的时候 xff0c 自顶向下 xff08 top down xff09 为我们提供了一种可靠的解决方法 自顶向下法将复杂的大问题分解为相
  • SecureCRT图形界面(通过设置调用Xmanager - Passive程序)

    首先 xff0c 在服务器进行设置 如果服务器是图形化界面启动的 xff0c xhost 43 命令可以不用执行 root 64 test xhost 43 xhost unable to open display 34 34 设置disp
  • 一种GPS辅助的多方位相机的VIO——Slam论文阅读

    34 A GPS aided Omnidirectional Visual Inertial State Estimator in Ubiquitous Environments 34 论文阅读 这里写目录标题 34 A GPS aided
  • docker & LXC

    目录 一 LXC1 了解Docker的前生LXC2 LXC与docker的关系3 与传统虚拟化对比4 LXC部署4 1 安装LXC软件包和依赖包4 2 启动服务4 3 创建虚拟机 5 LXC常用命令 二 doker1 什么是docker2
  • curl命令总结

    curl no cache d Users Administrator Desktop curl 7 73 0 3 win64 mingw bin gt curl Iv http abc gkmang cn 8081 index php l
  • 使用FastJSON 对Map/JSON/String 进行互转

    前言 Fastjson是一个Java语言编写的高性能功能完善的JSON库 xff0c 由阿里巴巴公司团队开发的 1 主要特性 高性能 fastjson采用独创的算法 xff0c 将parse的速度提升到极致 xff0c 超过所有json库
  • ai面向分析_2020年面向企业的顶级人工智能平台

    ai面向分析 In the long term artificial intelligence and automation are going to be taking over so much of what gives humans
  • 回答问题人工智能源码_回答21个最受欢迎的人工智能问题

    回答问题人工智能源码 Artificial intelligence sets the stage for a new era of solutions to be made with computers It allows us to s
  • 人工智能药物设计_用AI革新药物安全

    人工智能药物设计 介绍 Introduction Advances in the life sciences have brought about a transformative impact on healthcare with lif
  • 数据集分为训练验证测试_将数据集分为训练集,验证集和测试集

    数据集分为训练验证测试 测试我们的模型 Testing Our Model Supervised machine learning algorithms are amazing tools capable of making predict
  • 深度学习 场景识别_使用深度学习进行自然场景识别

    深度学习 场景识别 Recognizing the environment in one glance is one of the human brain s most accomplished deeds While the tremen
  • 使用Pytorch的LSTM文本分类

    Photo by Christopher Gower on Unsplash Christopher Gower在 Unsplash上的 照片 介绍 Intro Welcome to this tutorial This t