NLP中BERT在文本二分类中的应用

2023-11-20

最近参加了一次kaggle竞赛Jigsaw Unintended Bias in Toxicity Classification,经过一个多月的努力探索,从5月20日左右到6月26日提交最终的两个kernel,在public dataset上最终排名为4%(115/3167),说实话以前也并没有怎么接触过NLP方面的东西,对深度学习的理解也不是特别深刻。
BERT是目前非常火的NLP模型,采用两段式的训练方式,分为pretrain和fine-tune两部分,pretrain部分由谷歌在TPU集群上训练完成,并给出部分模型供免费下载使用。其中包括’cased_l-24_h-1024_a-16’, ‘chinese_l-12_h-768_a-12’, ‘uncased_l-12_h-768_a-12’, ‘uncased_l-24_h-1024_a-16’, ‘multi_cased_l-12_h-768_a-12’, ‘cased_l-12_h-768_a-12’
在这里插入图片描述
github上的下载地址为
https://github.com/google-research/bert
在本次比赛中主要采用的是BERT-Base,uncased的这个model,训练全数据180万左右的样本,样本长度设为220,在四卡Titan的GPU卡上训练的时间接近6小时,BERT-Large,uncased的训练时间大概是33小时左右。uncased和cased的区别在于uncased将全部样本变为小写,而cased则要区分大小写,将cased和uncased的模型训练结果进行融合也会有一定程度的提升。
比赛的主要代码是在一个以色列大佬的kernrl上进行修改,比赛结束时已经被fork了将近1595,代码地址
https://www.kaggle.com/yuval6967/toxic-bert-plain-vanila,非常感谢这位无私的大佬

代码详解

将样本文本转为成bert-format

def convert_lines(example, max_seq_length,tokenizer):
    max_seq_length -=2
    all_tokens = []
    longer = 0
    for text in tqdm_notebook(example):
        tokens_a = tokenizer.tokenize(text)
        if len(tokens_a)>max_seq_length:
            tokens_a = tokens_a[:max_seq_length]
            longer += 1
        one_token = tokenizer.convert_tokens_to_ids(["[CLS]"]+tokens_a+["[SEP]"])+[0] * (max_seq_length - len(tokens_a))
        all_tokens.append(one_token)
    print(longer)
    return np.array(all_tokens)

应用BERT模型进行词嵌入,转换成sequences

bert_config = BertConfig('../input/bert-pretrained-models/uncased_l-12_h-768_a-12/uncased_L-12_H-768_A-12/'+'bert_config.json')
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_PATH, cache_dir=None,do_lower_case=True)
train_df = pd.read_csv(os.path.join(Data_dir,"train.csv")).sample(num_to_load+valid_size,random_state=SEED)
train_df['comment_text'] = train_df['comment_text'].astype(str) 
sequences = convert_lines(train_df["comment_text"].fillna("DUMMY_VALUE"),MAX_SEQUENCE_LENGTH,tokenizer)
#此处可以对其进行保存,下次可以直接调用npy文件
#np.save('sequences.npy',sequences)
#sequences = np.load('../wql/wql_base_sequences.npy')

通过uncased_L-12_H-768_A-12文件下的bert_model.ckpt和bert_config.json文件,在working目录下生成pytorch_model.bin文件,并copy其中的bert_config.json到working目录下

BERT_MODEL_PATH = '../input/bert-pretrained-models/uncased_l-12_h-768_a-12/uncased_L-12_H-768_A-12/'
convert_tf_checkpoint_to_pytorch.convert_tf_checkpoint_to_pytorch(
    BERT_MODEL_PATH + 'bert_model.ckpt',
BERT_MODEL_PATH + 'bert_config.json',
WORK_DIR + 'pytorch_model.bin')
shutil.copyfile(BERT_MODEL_PATH + 'bert_config.json', WORK_DIR + 'bert_config.json')

实例化模型

train_dataset = torch.utils.data.TensorDataset(torch.tensor(X,dtype=torch.long), torch.tensor(y,dtype=torch.float))
#实例化模型,确保working目录下有bert_config.json文件和pytorch_model.bin文件,
model = BertForSequenceClassification.from_pretrained('./working/',cache_dir=None,num_labels=1)
#四卡并行计算的代码
devices = [0,1,2,3]
if len(devices) > 1:
    print('use multi gpus')
    model = nn.DataParallel(model, device_ids=devices)
    model = model.cuda(devices[0])
#定义权重衰减的参数
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
model=model.train()

设置bert训练的参数

lr=2e-5
#若batch_size大可能导致cuda memory out的问题,可以通过减低batchsize来解决,在训练large时,batchsize设过4
batch_size = 100
#用于参数更新的加速,设为2,参数更新的次数为1的1/2
accumulation_steps=2
save_steps = 1000
checkpoint = None
# 全数据训练的epoch次数
EPOCHS = 1
num_train_optimization_steps = int(EPOCHS*len(train) / batch_size / accumulation_steps)
optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=lr,  #在epoch2时可适当降低lr,例如取其1/2
                     warmup=0.05,   #当lr下降的时候,也可以降低warmup,例如epoch2时设为0
                     t_total=num_train_optimization_steps)

开始进行一个epoch的全样本数据训练,batch_size为100

tq = tqdm_notebook(range(EPOCHS))
for epoch in tq:
    #将avg_loss和avg_accuracy写入txt文件
    file_name = 'loss_log_' + 'epoch' + str(epoch) + '.txt'
    file = open(file_name, 'w', encoding='utf-8')
    
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
    avg_loss = 0.
    avg_accuracy = 0.
    optimizer.zero_grad()   
    for (x_batch, y_batch) in tqdm(train_loader):
        y_pred = model(x_batch.cuda(), attention_mask=(x_batch>0).cuda(), labels=None)
        loss = nn.随意一个loss-function()(y_pred, y_batch.cuda())
        loss.backward()
        optimizer.step()  # Now we can do an optimizer step
        optimizer.zero_grad()
        avg_loss += loss.item() / len(train_loader)
        avg_accuracy += torch.mean(((torch.sigmoid(y_pred[:,0])>0.5) == (y_batch[:,0]>0.5).cuda()).to(torch.float) ).item()/len(train_loader)
        i += 1
        
        file.write('batch' + str(i) + '\t' + 'avg_loss' + '=' + str(avg_loss) + '\t' + 'avg_accuracy' + '=' + str(avg_accuracy) + '\n')
    file.close()
    
    file_path = output_model_file + str(epoch) +'.bin'
    #保存bert model用于迁移学习或测试
    torch.save(model.state_dict(), file_path)
    print(f'loss:{avg_loss} accuracy:{avg_accuracy}')

应用过程中的一点理解

就我目前对bert的理解和在比赛中调参的过程来浅谈一点经验,大佬贴出了bert模型的模板,几乎所有选手都采用这位大佬的模型作为原本进行改写,有几个重要点可以进行改写来提高预测的准确率

loss-function

首先是loss-function,这决定了这次比赛的成败,如果没法写出一个适合比赛的lossfunction,那么比赛最终成绩也就是可以fork到的最高分。虽然大佬在bert fine-tune中只给出了F.binary_cross_entropy_with_logits,但是在其他LSTM的kernel上却给出了下面这样比较优秀的loss

def custom_loss(data, targets):
     bce_loss_1 = nn.BCEWithLogitsLoss(weight=targets[:,1:2])(data[:,:1],targets[:,:1])
     bce_loss_2 = nn.BCEWithLogitsLoss()(data[:,2:],targets[:,2:])
     return (bce_loss_1 * loss_weight) + bce_loss_2

可以通过一定的补充,应用到bert模型中去,仅仅是这样一步就使得bert的预测准确率提高了很多

epoch

在对全样本数据进行fine-tune的过程中发现,训练到第三个epoch的时候,bert模型就发生了过拟合,通过小样本数据多次进行lr的调整之后也得到同样的结果,因此epoch的数量设置为2,比赛结束后一些高分单个bert的epoch数量也确实是2。

参数调整

在我看来fine-tune过程中可以调节的参数应该有以下几个,lr、warmup、dropout、accumulation_steps、batch_size、weight_decay,MAX_SEQUENCE_LENGTH

hidden_dropout_prob: The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention probabilities.

accumulation_steps减小从实践效果来说可以略微提升训练效果。
weight_decay 调整之后并没有收到准确率提升的效果
lr的话,第一个epoch为2e-5,第二个epoch为1e-5是相对较好的lr参数
batch_size对于结果影响并不大,越大的batchsize计算速度越快,但是对于gpu的显存要求更高
MAX_SEQUENCE_LENGTH 规定了sequence的长度,长度越长对显存要求也越高
warmup 一个高分的bert模型中,epoch0 lr为2e-5,warmup为0.05,epoch1 lr为1e-5,warmup为0,对结果有一定的提升
给出一高分的bert模型地址 https://www.kaggle.com/hanyaopeng/single-bert-base-with-0-94376
其实我对bert理解也只停留在模型的应用上,如有了解的大佬看到觉得不对的地方,希望指出错误。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

NLP中BERT在文本二分类中的应用 的相关文章

  • 比较文本文档含义的最佳方法?

    我正在尝试找到使用人工智能和机器学习方法来比较两个文本文档的最佳方法 我使用了 TF IDF Cosine 相似度和其他相似度度量 但这会在单词 或 n gram 级别上比较文档 我正在寻找一种方法来比较meaning的文件 最好的方法是什
  • target_vocab_size 在方法 tfds.features.text.SubwordTextEncoder.build_from_corpus 中到底意味着什么?

    根据这个链接 https www tensorflow org datasets api docs python tfds features text SubwordTextEncoder build from corpus target
  • Spacy 中的自定义句子分割

    I want spaCy使用我提供的句子分割边界而不是它自己的处理 例如 get sentences Bob meets Alice SentBoundary They play together gt Bob meets Alice Th
  • 语音识别中如何处理同音词?

    对于那些不熟悉什么是同音字 https en wikipedia org wiki Homophone是的 我提供以下示例 我们的 是 嗨和高 到 太 二 在使用时语音API https developer apple com docume
  • SpaCy 模型“en_core_web_sm”的词汇量大小

    我尝试在 SpaCy 小模型中查看词汇量 model name en core web sm nlpp spacy load model name len list nlpp vocab strings 只给了我 1185 个单词 我也在同
  • 如何对德语文本进行词形还原?

    我有一篇德语文本 我想对其应用词形还原 如果不可能进行词形还原 那么我也可以接受词干提取 Data 这是我的德语文本 mails Hallo Ich spielte am fr hen Morgen und ging dann zu ein
  • BERT 输出不确定

    BERT 输出是不确定的 当我输入相同的输入时 我希望输出值是确定性的 但我的 bert 模型的值正在变化 听起来很尴尬 同一个值返回两次 一次 也就是说 一旦出现另一个值 就会出现相同的值并重复 如何使输出具有确定性 让我展示我的代码片段
  • 使用正则表达式标记化进行 NLP 词干提取和词形还原

    定义一个函数 名为performStemAndLemma 它需要一个参数 第一个参数 textcontent 是一个字符串 编辑器中给出了函数定义代码存根 执行以下指定任务 1 对给出的所有单词进行分词textcontent 该单词应包含字
  • SpaCy 中的自定义句子边界检测

    我正在尝试在 spaCy 中编写一个自定义句子分段器 它将整个文档作为单个句子返回 我编写了一个自定义管道组件 它使用以下代码来执行此操作here https github com explosion spaCy issues 1850 但
  • NLTK:包错误?朋克和泡菜?

    基本上 我不知道为什么会收到此错误 只是为了获得更多图像 这里有一个代码格式的类似消息 由于是最新的 该帖子的答案已经在消息中提到 Preprocessing raw texts LookupError Traceback most rec
  • PHP 和 NLP:嵌套括号(解析器输出)到数组?

    想要将带有嵌套括号的文本转换为嵌套数组 以下是 NLP 解析器的输出示例 TOP S NP PRP I VP VBP love NP NP DT a JJ big NN bed PP IN of NP NNS roses 原文 我喜欢一大床
  • 如何使用动词时态/语气制作稀疏匹配器模式?

    我一直在尝试使用动词时态和情绪为 spacy 匹配器创建一个特定的模式 我发现了如何使用 model vocab morphology tag map token tag 访问使用 spacy 解析的单词的形态特征 当动词处于虚拟语气模式
  • 使用 OpenNLP 获取句子的解析树。陷入困境。

    OpenNLP 是一个关于自然语言处理的 Apache 项目 NLP 程序的目标之一是解析一个句子 并给出其语法结构的树 例如 天空是蓝色的 这句话 可能会被解析为 S NP VP The sky is blue where S是句子 NP
  • 对产品列表进行分类的算法? [关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我有一个代表或多或少相同的产品的列表 例如 在下面的列表中 它们都是希捷硬盘 希捷硬盘 500Go 适用于笔记本电脑的希捷硬盘 120
  • 文本摘要评估 - BLEU 与 ROUGE

    根据两个不同摘要系统 sys1 和 sys2 的结果和相同的参考摘要 我使用 BLEU 和 ROUGE 对它们进行了评估 问题是 sys1 的所有 ROUGE 分数均高于 sys2 ROUGE 1 ROUGE 2 ROUGE 3 ROUGE
  • 如何将句子或文档转换为向量?

    我们有将单词转换为向量的模型 例如 word2vec 模型 是否存在类似的模型 可以使用为单个单词学习的向量将句子 文档转换为向量 1 跳克法 以及使用它的工具 谷歌 word2vec https code google com p wor
  • 使用自定义层运行 Keras 模型时出现问题

    我目前正在攻读学士学位论文FIIT STU https www fiit stuba sk en html page id 749 其主要目标是尝试复制和验证以下结果study http arxiv org abs 2006 00885 这
  • 在Python中表示语料库句子的一种热门编码

    我是 Python 和 Scikit learn 库的初学者 我目前需要从事一个 NLP 项目 该项目首先需要通过 One Hot Encoding 来表示一个大型语料库 我已经阅读了 Scikit learn 关于 preprocessi
  • 池化与随时间池化

    我从概念上理解最大 总和池中发生的情况作为 CNN 层操作 但我看到这个术语 随时间变化的最大池 或 随时间变化的总和池 例如 用于句子分类的卷积神经网络 https arxiv org pdf 1408 5882 pdfYoon Kim
  • 如何在 scikit-learn 的 SVM 中使用非整数字符串标签? Python

    Scikit learn 具有相当用户友好的用于机器学习的 python 模块 我正在尝试训练用于自然语言处理 NLP 的 SVM 标记器 其中我的标签和输入数据是单词和注释 例如 词性标记 而不是使用双精度 整数数据作为输入元组 1 2

随机推荐

  • 国际软件项目经理的七大素质

    国际软件项目经理的七大素质 1 在一个或多个应用领域内使用整合了道德 法律和经济问题的工程方法来设计合适的解决方案 2 懂得确定客户需求并将其转换成软件需求的过程 3 履行项目经理的职责 善于处理技术和管理方面的事务 4 懂得并使用有用的项
  • 人脸特征点检测

    CVPR2016刚刚落下帷幕 本文对面部特征点定位的论文做一个简单总结 让大家快速了解该领域最新的研究进展 希望能给读者们带来启发 CVPR2016相关的文章大致可以分为三大类 处理大姿态问题 处理表情问题 处理遮挡问题 1 姿态鲁棒的人脸
  • 描述性能测试工作中的完整过程?

    有简单接触 采用的工具是Jmeter 进行轻量级的压力测试 1 确定好压力测试的功能模块 首先用Jmeter录制脚本 然后对脚本进行优化 2 对一些数据进行参数化 利用CSV导入存在txt文档里面的数据 3 设计测试场景 4 执行压力测试
  • 如何在windows的DOS窗口中正常显示中文(UTF-8字符)

    打开CMD exe命令行窗口 通过 chcp命令改变代码页 UTF 8的代码页为65001 ANSI OEM 简体中文 GBK为936 window default OEM 美国为437 如果chcp命令得到437 那么一定不能显示中文 此
  • 无法安装vmnet8虚拟网络适配器、vmware network editor未响应、注册失败,请检查账号数据库配置是否正确的解决

    文章目录 虚拟网络适配器安装 vmware network editor未响应 注册失败 请检查账号数据库配置是否正确的解决 关于第一次安装虚拟机的 全文约 423 字 预计阅读时长 2分钟 虚拟网络适配器安装 vmware network
  • rol/ror in c++

    template
  • 20天拿下华为OD笔试之【BFS】2023Q1A-微服务的集成测试【闭着眼睛学数理化】全网注释最详细分类最全的华为OD真题题解

    BFS 2023Q1A 微服务的集成测试 题目描述与示例 题目描述 现在有 n 个容器服务 服务的启动可能有一定的依赖性 有些服务启动没有依赖 其次服务自身启动加载会消耗一些时间 给你一个 nxn 的二维矩阵 useTime 其中 useT
  • simulink仿真adc采样和epwm输出基础知识讲解

    F28027 12位ADC 2的y次方 tbclk 计数时钟的频率 tprd 一个周期内记得个数 1 tbclk 每次计一个数的时间 一个pwm周期的时间 pwm的周期 时基计数器 CRT 计数时钟由系统时钟分频来的 比较寄存器 CMR 决
  • 大数据、数据分析和数据挖掘的区别

    大数据 数据分析 数据挖掘的区别是 大数据是互联网的海量数据挖掘 而数据挖掘更多是针对内部企业行业小众化的数据挖掘 数据分析就是进行做出针对性的分析和诊断 大数据需要分析的是趋势和发展 数据挖掘主要发现的是问题和诊断 1 大数据 big d
  • 软件项目管理的平衡原则和高效原则

    1 平衡原则 在我们讨论软件项目为什么会失败时 列出了很多的原因 答案有很多 如管理问题 技术问题 人员问题等等 但是 有一个根本的问题是最容易被忽视的 也是软件系统的用户 软件开发商 销售代理商最不愿证实的 那就是 需求 资源 工期 质量
  • 计算机网络 网络层——IP数据报 详记

    IP 数据报的格式 一个 IP 数据报由首部和数据两部分组成 首部的前一部分是固定长度 共 20 字节 是所有 IP 数据报必须具有的 在首部的固定部分的后面是一些可选字段 其长度是可变的 IP数据报首部的固定部分中的各字段 版本 占4位
  • 信号量机制

    简介 信号量是一种数据结构 信号量的值与相应资源的使用情况有关 信号量的值由P V操作改变 常用信号量 整型信号量 整型信号量S的等待 唤醒机制 P V操作 wait S while S lt 0 do no op s signal S S
  • python字符串与列表

    字符串 字符串定义 输入输出 定义 切片是指对操作的对象截取其中一部分的操作 适用范围 字符串 列表 元组都支持切片操作 切片的语法 起始下标 结束 步长 字符串中的索引是从 0 开始的 最后一个元素的索引是 1 字符串的常见操作 查找 f
  • centos7搭建ftp服务器及ftp配置讲解

    ftp 即文件传输 它是INTERNET上仍然常用的最老的网络协议之一 它为系统提供了通过网络与远程服务器传输的简单方法 FTP服务器包的名称为vsftpd 一 vsftpd安装 并简单配置启动 安装 很简单 一句话 yum install
  • Socket接收数据耗时

    1 遇到问题 首先说明一下我遇到的问题 服务端传递Byte数组 长度在900w 客户端接收时会耗时10s 我的代码是这样的 2 Socket缓冲区 http t zoukankan com bigberg p 7747419 html 每个
  • 即刻掌握python格式化输出的三种方式 (o゜▽゜)o☆

    目录 1 f 转化的格式化输出方式 2 格式化输出的方法 3 format 格式化输出的方法 1 f 转化的格式化输出方式 只需要在我们要格式化输出的内容开头引号的前面加上 f 在字符串内要转义的内容用 括起来即可 模板 print f x
  • 企业微信登录-前端实现

    企业微信登录 企业微信登录 前端具体实现 下面代码中配置项的字段具体用途说明可以阅读企业微信开发者说明文档 我们通过提供的企业微信登录组件来进行站内登录 下面是我封装的登录组件以及使用方法 weChatLogin vue 封装的组件
  • hudi-hive-sync

    hudi hive sync Syncing to Hive 有两种方式 在hudi 写时同步 使用run sync tool sh 脚本进行同步 1 代码同步 改方法最终会同步元数据 但是会抛出异常 val spark SparkSess
  • spring:AOP面向切面编程+事务管理

    目录 一 Aop Aspect Oriented Programming 二 springAOP实现 1 XML实现 2 注解实现 三 spring事务管理 一 Aop Aspect Oriented Programming 将程序中的非业
  • NLP中BERT在文本二分类中的应用

    最近参加了一次kaggle竞赛Jigsaw Unintended Bias in Toxicity Classification 经过一个多月的努力探索 从5月20日左右到6月26日提交最终的两个kernel 在public dataset