Pytorch版本的BERT使用学习笔记

2023-10-29

一、Google BERT:

BERT地址:https://github.com/google-research/bert

pytorch版本的BERT:https://github.com/huggingface/pytorch-pretrained-BERT

使用要求:Python 3.5+  &  PyTorch0.4.1/1.0.0  &  pip install pytorch-pretrained-bert & 下载BERT-模型

二、BERT-模型

  • BERT-Base, Multilingual (Not recommended, use Multilingual Cased instead): 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
  • BERT-Base, Chinese: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters

中文模型【BERT-Base-Chinese】:https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz

三、简单介绍

1)Bidirectional Encoder Representations from Transformers

  1. the first unsuperviseddeeply bidirectional system for pre-training NLP,且上下文相关。
  2. train a large model (12-layer to 24-layer Transformer) on a large corpus (Wikipedia + BookCorpus) for a long time (1M update steps)

(曾经的 Semi-supervised Sequence LearningGenerative Pre-TrainingELMo, and ULMFit 只学到了单侧信息。)

2.)学习词的表示:BERT mask 了15%的word,如:

Input: the man went to the [MASK1] . he bought a [MASK2] of milk.
Labels: [MASK1] = store; [MASK2] = gallon

3. )学习句子间信息:

Sentence A: the man went to the store .
Sentence B: he bought a gallon of milk .
Label: IsNextSentence

四、BERT的使用

4.1、两种使用方式

1)Pre-training :four days on 4 to 16 Cloud TPUs

  • BERT-Base, Chinese: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters

2)fine-tuning :a few hours on a GPU 

  • The fine-tuning examples which use BERT-Base should be able to run on a GPU that has at least 12GB of RAM using the hyperparameters given.

4.2、Out-of-memory 如何解决

when using a GPU with 12GB - 16GB of RAM, you are likely to encounter out-of-memory issues if you use the same hyperparameters described in the paper,调整以下参数:

  • max_seq_length: 训好的模型用512,可以调小
  • train_batch_size
  • Model type, BERT-Base vs. BERT-Large: The BERT-Large model requires more memory.
  • Optimizer: 训好的模型用Adam, requires a lot of extra memory for the m and vectors. Switching to a more memory efficient optimizer can reduce memory usage, but can also affect the results. 

4.3、Pytorch-BERT的使用

原始BERT是运行在TF上的,TF-BERT使用可参考:https://www.jianshu.com/p/bfd0148b292e

Pytorch版本BERT组成如下:

1)Eight Bert PyTorch models

  • BertModel - raw BERT Transformer model (fully pre-trained),
  • BertForMaskedLM - BERT Transformer with the pre-trained masked language modeling head on top (fully pre-trained),
  • BertForNextSentencePrediction - BERT Transformer with the pre-trained next sentence prediction classifier on top (fully pre-trained),
  • BertForPreTraining - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (fully pre-trained),
  • BertForSequenceClassification - BERT Transformer with a sequence classification head on top (BERT Transformer is pre-trained, the sequence classification head is only initialized and has to be trained),
  • BertForMultipleChoice - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is pre-trained, the multiple choice classification head is only initialized and has to be trained),
  • BertForTokenClassification - BERT Transformer with a token classification head on top (BERT Transformer is pre-trained, the token classification head is only initialized and has to be trained),
  • BertForQuestionAnswering - BERT Transformer with a token classification head on top (BERT Transformer is pre-trained, the token classification head is only initialized and has to be trained).

2)Tokenizers for BERT (using word-piece) (in the tokenization.py file):

  • BasicTokenizer - basic tokenization (punctuation splitting, lower casing, etc.),
  • WordpieceTokenizer - WordPiece tokenization,
  • BertTokenizer - perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.

3)Optimizer for BERT (in the optimization.py file):

  • BertAdam - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
  • 源码中:
  • __init__(params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', betas=(0.9, 0.999), e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs)
  • lr: learning rate
  • warmup: portion of t_total for the warmup, -1  means no warmup. 【使用部分t_total热身】
  • t_total: total number of training steps for the learning rate schedule, -1  means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1【总训练步骤】
  • schedule: schedule to use for the warmup (see above).
    Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below).
    If `None` or `'none'`, learning rate is always kept constant. Default : `'warmup_linear'`
  • eg:使用方式:train_optimi_step = int(train_iter_num / args.batch_size) * args.epochs
    optimizer = BertAdam([param for _, param in param_optimizer], lr=args.lr, warmup=0.1, t_total=train_optimi_step)

4)Five examples on how to use BERT (in the examples folder):

  • extract_features.py - Show how to extract hidden states from an instance of BertModel,
  • run_classifier.py - Show how to fine-tune an instance of BertForSequenceClassification on GLUE's MRPC task,
  • run_squad.py - Show how to fine-tune an instance of BertForQuestionAnswering on SQuAD v1.0 and SQuAD v2.0 tasks.
  • run_swag.py - Show how to fine-tune an instance of BertForMultipleChoice on Swag task.
  • run_lm_finetuning.py - Show how to fine-tune an instance of BertForPretraining on a target text corpus.

五、BERT的使用代码

使用Pytorch版本BERT使用方式如下:

1)First prepare a tokenized input with BertTokenizer

import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# 加载词典 pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']

# 将 token 转为 vocabulary 索引
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# 定义句子 A、B 索引
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# 将 inputs 转为 PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

2)use BertModel to get hidden states

# 加载模型 pre-trained model (weights)
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

# GPU & put everything on cuda
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda')

# 得到每一层的 hidden states 
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, segments_tensors)
# 模型 bert-base-uncased 有12层,所以 hidden states 也有12层
assert len(encoded_layers) == 12

3)use BertForMaskedLM

# 加载模型 pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

# cuda
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda')

# Predict all tokens
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)

# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'

1. BertModel 

输入: in modeling.py

  • input_ids: torch.LongTensor  [batch_size, sequence_length] with the word token indices in the vocabulary.
  • token_type_ids: optional torch.LongTensor [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a sentence A and type 1 corresponds to a sentence B token.
  • attention_mask: torch.LongTensor [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used when a batch has varying length sentences.
  • output_all_encoded_layers: controls the content of the encoded_layers output as described below. Default: True.

输出:

  • encoded_layers: 取决于output_encoded_layers 参数:
    • output_all_encoded_layers=True:
    • 输出一列 encoded-hidden-states at the end of each attention block (12 full sequences for BERT-base, 24 for BERT-large), 每个 encoded-hidden-state= [batch_size, sequence_length, hidden_size]
    • output_all_encoded_layers=False:
    • 输出最后一个attention block对应的encoded-hidden-states,1个 [batch_size, sequence_length, hidden_size]
  • pooled_output: torch.FloatTensor [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated [CLF] to train on the Next-Sentence task.
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch版本的BERT使用学习笔记 的相关文章

  • 源码安装zabbix

    源码安装zabbix 1 zabbix简介 2 zabbix的主要功能 3 监控指标 4 zabbix配置文件 5 服务器端配置文件 6 客户端配置文件 7 部署zabbix 8 先安装apache 8 1 安装开发工具包 8 2 下载ap
  • python使用plt.savefig保存时图片保存不完整,四周留白太多

    问题 今天在给论文添加曲线图 使用matplotlib的savefig函数中遇到图片保存不完整 且留白太多的问题 插入到论文中很难看 左边是在pycharm里的截图 右边是通过savefig保存的图片 代码如下 保存图片 plt savef
  • 怎样在PB中实现文件的拷贝与删除 (转)

    怎样在PB中实现文件的拷贝与删除 转 more 怎样在PB中实现 文件的拷贝与删除 可借助外部 函数 可用 api函数实现 1 文件拷贝 函数声明 FUNCTION boolean CopyFileA ref string cfrom re
  • 安装及使用ceres遇到过的问题

    首先ceres安装及使用需要两个依赖库glog与gflags 这两个库安装时需要注意要么只用apt install要么只用源码编译安装 如果两个都装了用apt purge或者在 usr local lib及 usr local includ
  • 设置Unity的帧率

    问题描述 Xsens接收数据的频率是30Hz 为了接收到正确 完整的数据 要将Unity的帧率换成30帧 第一种方法 点击Editor gt Project Setting gt Quality 将VSync Count那选择Every S
  • 测试框架pytest教程(4)运行测试

    运行测试文件 pytest q test example py 会运行该文件内test 开头的测试方法 该 q quiet标志使输出保持简短 测试类 pytest的测试用例可以不写在类中 但如果写在类中 类名需要是Test开头 非Test开
  • C/C++中使用Base64编码解码(使用boost库)

    Base64是一种用64个字符来表示任意二进制数据的方法 用记事本打开exe jpg pdf这些文件时 我们都会看到一大堆乱码 因为二进制文件包含很多无法显示和打印的字符 所以 如果要让记事本这样的文本处理软件能处理二进制数据 就需要一个二
  • c++判断硬盘是否连接、修改指定盘符、读取配置文件、获取exe路径

    系统 win7 64 编译器 vs2010 mfc对话框程序 工程名 fixde 语言 c 涉及函数 GetPrivateProfileString 读取配置文件内容 类型为string GetDriveType 获取某驱动器状态 GetV
  • w5500 php,[W5500]搭建属于你的家庭网络实时监控

    图9 OV2640 SVGA模式下图像输出时序图 系统上电后 MCU配置OV2640的工作方式 在OV2640准备好图像后 VSYNC会被拉高一段时间 MCU通过PCLK上升沿中断按字节接收图像数据 接下来我们将对OV2640的初始化配置程
  • 编写测试用例的基本方法之边界值

    一般边界值分析是因为程序开发循环体时的取数可能会因为 lt lt 搞错 比如下面代码 for int i 0 i lt 100 i int j i 1 System out println 循环第 j 次 循环地做某件事情 这里的程序是循环
  • 如何点击按钮把光标定位到想要的el-input中

    1 el inpu
  • C语言-结构体面向对象编程技巧

    Keil4 C51工程网址 https yunpan 360 cn surl yrNkQSrCKyc 一 面向对象 面向对象是软件开发方法 是相对于面向过程来讲的 通过把数据与方法组织为一个整体来看待 从更高的层次来进行系统建模 更贴近事物
  • python3 数据类型归纳

    1 简介 1 1 python3 数据类型 类型 含义 示例 int 整型 1 float 浮点型 1 0 bool 布尔值 True或False complex 复数 a bj string 字符串 abc123 list 列表 a b
  • 别被骗了,win10家庭版MMC是无法创建管理单元的

    今天我的PC 是win10家庭版 无法出现gpedit msc 然后下载gpedit msc 后 又出现MMC无法创建管理单元 家庭版是没有组策略和注册表的 家庭版是没有组策略和注册表的 家庭版是没有组策略和注册表的 重要的事情说三遍 下面
  • 长轮询与长连接

    实现即时通讯主要有四种方式 它们分别是轮询 长轮询 comet 长连接 SSE WebSocket 它们大体可以分为两类 一种是在HTTP基础上实现的 包括短轮询 comet和SSE 另一种不是在HTTP基础上实现是 即WebSocket
  • Spring笔记【黑马】

    Spring day01 今日目标 掌握Spring相关概念 完成IOC DI的入门案例编写 掌握IOC的相关配置与使用 掌握DI的相关配置与使用 1 课程介绍 对于一门新技术 我们需要从为什么要学 学什么以及怎么学这三个方向入手来学习 那
  • 三点估算法评估开发工作量

    概述 开发人员在进行开发工作之前都需要给出一个工作量的评估 以便后续的工作任务可以基于该时间进行排期 大多数开发人员评估工作量主要是基于过往的工作经验拍脑袋决定 并会给自己预留出一定的Buffer时间 这样可能的问题就是过度依赖个人的开发经
  • selenium爬虫_selenium爬虫如何避免对isTrusted属性检测?

    1 前言 各位码友 有两天不见 想小码哥了没 哈哈哈 成都疫情在平静9个月之后 又死灰复燃 目前还未找到确切的源头 提醒各位成都的码友一定注意戴口罩 做好自我防护 相信有关部门的防疫措施 一起共渡难关 好了 今天咱们再继续selenium爬
  • MyBatis---缓存-提高检索效率的利器

    目录 让我们来看看官方文档 缓存 一 一级缓存 1 基本介绍 2 一级缓存 3 一级缓存失效分析 二 二级缓存 1 基本介绍 2 二级缓存快速入门 2 1快速入门 3 注意事项和使用陷阱 三 Mybatis 的一级缓存和二级缓存执行顺序 四
  • C++对象模型之内存区的使用

    对象模型是面向对象程序设计语言的一个重要方面 它会直接影响面向对象语言编写程序的运行机制及对内在的使用机制 因此了解对象模型是进行程序优化 的基础 分析一般意义上程序中的数据在内存中的分布 以及程序使用的不同种类的内存等基本的概念 了解对象

随机推荐

  • 第八章(3) 聚类:DBSCAN和簇评估

    基于密度的聚类寻找被低密度区域分离的高密度区域 传统的密度 基于中心的方法 数据集中特定点的密度通过对该点半径之内的点计数 包括本身 来估计 关键是确定半径 根据基于中心的密度进行点分类 稠密区域内的点 核心点 点的邻域由距离函数和指定半径
  • 微信小程序(十)之消息推送配置(token验证失败的解决方案)

    背景 微信小程序开发 准备使用模板消息做些事情 但是发现需要先在微信公众平台的开发 开发设置 消息推送做配置 然后我们后台人员就开始各种配置 但是一到验证token就报错 很是郁闷 然后各种排查 发现了最终原因 过程和代码如下 很多网站给出
  • datx 开启debug

    1 datax源码编译 编译 mvn U clean package assembly assembly Dmaven test skip true 2 创建mysql测试表 SET FOREIGN KEY CHECKS 0 Table s
  • mysql中 SET autocommit=0 与 START TRANSACTION 的区别

    在MySQL中 SET autocommit 0 指事务非自动提交 自此句命令执行以后 每个SQL语句或者语句块所在的事务都需要显式调用commit才能提交事务 不管autocommit 是1还是0 START TRANSACTION co
  • ssh 配置文件中 maxsessions 与 MaxStartups

    MaxStartups 同时允许几个尚未登入的联机画面 所谓联机画面就是在你ssh登录的时候 没有输入密码的阶段 如下图 maxsessions 同一地址的最大连接数 也就是同一个IP地址最大可以保持多少个链接 转载于 https blog
  • CentOS基础命令大全

    1 关机 立即关机 shutdown h now 立即关机 init 0 立即关机 telinit 0 预约时间关机 shutdown h hours minutes 取消预约关机 shutdown c 重启 shutdown r now
  • 一些关于远程仓库操作的git指令

    1 更换项目所关联的仓库 要先删除目前的远程仓库 然后再添加新的远程仓库 1 git remote rm origin 2 git remote add origin 新的仓库地址 3 git remote v 查看现在的远程仓库 4 gi
  • Hausdorff 距离

    Hausdorff 距离是描述两组点集之间相似程度的一种量度 假设有两组集合 则这两个点集之间的单向 Hausdorff 距离 其中 a b 表示 a 与 b 之间的欧氏距离 h A B 也叫前向 Hausdorff 距离 h B A 也叫
  • Android9.0 mm编译失败:ninja: error: 'xxx', needed by 'xxx', missing and no known rule to make it

    Android系统源码环境下使用mm命令单独编译某一个模块 如果该模块依赖其它模块 可能会报如下错误 解决此问题的方法就是改成mma命令编译 mma命令会构建所需要的关联模块 编译命令简单总结 mm 编译当前目录下的模块 当前目录下要有An
  • 获取成员函数地址及获取函数地址

    首先我们定义一个类Ctest 类里面包含三个不同形式的成员函数 静态成员函数statFunc 动态成员函数dynFunc 和虚拟函数virtFunc 在main函数中我们利用cout标准输出流分别输出这三个函数的地址 程序如下所示 incl
  • WebSocket协议深度解析

    WebSocket协议深度解析 1 WebSocket简介 WebSocket相比于Http协议 它有如下几个优点 支持双向通信 更灵活 更高效 可扩展性更好 支持双向通信 实时性更强 更好的二进制支持 较少的控制开销 连接创建后 ws客户
  • NER相关技术

    实体词典匹配 优点 缺点 模型原理 优点 缺点 模型输入 模型输出 实体词典匹配 模型预测两路结果是怎么合并输出的 目前我们采用训练好的CRF权重网络作为打分器 来对实体词典匹配 模型预测两路输出的NER路径进行打分 在词典匹配无结果或是其
  • 随机生成六位不重复数值

    在 Core JAVA 中有个随机生成六位不重复数值的算法 大二用过一次 今天在写 Algorithms 的练习题遇到类似的问题 特贴出 1 随机生成六位不重复的数字 2 private static int generate6BitInt
  • Mybatis的四种分页方式详解

    LIMIT关键字 mapper代码 select from tb user limit pageNo pageSize 业务层直接调用 public List findByPageInfo PageInfo info return user
  • 关于微信小程序在部分PC设备无法打开的问题

    目前为止微信小程序PC端仍处于灰度测试阶段 部分设备无法打开微信小程序 这个问题在启用分包能力后尤为明显 由于我们不能去控制用户通过PC端访问小程序的行为 仍需对PC端兼容性进行测试 下面我们来介绍测试方式 一 安装微信客户端大于3 4 5
  • 【LeetCode75】第五十四题 咒语和药水的成功对数

    目录 题目 示例 分析 代码 题目 示例 分析 题目给我们两个数组 要我们找出第一个数组中每个元素能和另一个数组的元素匹配的数量 匹配的条件是乘积大于特定的值 那么要乘积大于某个值 就需要乘数越大越好 我们可以把表示药水的数组升序排序 接着
  • aircrack-ng 介绍、功能测试及部分源码分析

    aircrack ng 介绍 功能测试及部分源码分析 实验目的 1 理清aircrack ng的总体设计框架 包括各模块的功能与联系 2 核心模块的实现原理 aircrack ng aireplay ng airodump ng 实验要求
  • C++中按引用传递参数

    C 中按引用传递参数 实参通常是通过值传递给函数的 这意味着形参接收的只是发送给它们的值的副本 它们存储在函数的本地内存中 对形参值进行的任何更改都不会影响原始实参的值 然而 有时候可能会希望一个函数能够改变正在调用中的函数 即调用它的函数
  • linux中的命令 参数 对象,Linux 系统命令、命令参数及命令对象之间,普遍应该使用()间隔? (5.0分)...

    对于实用新型和外观设计专利申请 我国专利法规定实行 是自然界的雷云直接对地面物体 建筑物 放电 它的破坏作用十分大 在标杆管理最佳企业管理实践的阶段 企业只有向同行业标杆企业学习最佳做法 才能提高企业竞争力 经济订购批量 外压容器设计采用的
  • Pytorch版本的BERT使用学习笔记

    一 Google BERT BERT地址 https github com google research bert pytorch版本的BERT https github com huggingface pytorch pretraine