BERT 微调后得到句子级嵌入

2023-12-23

我遇到了这个page https://colab.research.google.com/github/google-research/bert/blob/master/predicting_movie_reviews_with_bert_on_tf_hub.ipynb#scrollTo=KVB3eOcjxxm1

1)我想获得句子级嵌入(嵌入由[CLS]token)微调完成后。我怎样才能做到呢?

2)我还注意到该页面上的代码需要花费大量时间才能返回测试数据的结果。这是为什么?与我尝试获得测试预测时相比,当我训练模型时,花费的时间更少。 从该页面上的代码来看,我没有使用以下代码块

test_InputExamples = test.apply(lambda x: bert.run_classifier.InputExample(guid=None, 
                                                                       text_a = x[DATA_COLUMN], 
                                                                       text_b = None, 
                                                                       label = x[LABEL_COLUMN]), axis = 1

test_features = bert.run_classifier.convert_examples_to_features(test_InputExamples, label_list, MAX_SEQ_LENGTH, tokenizer)

test_input_fn = run_classifier.input_fn_builder(
        features=test_features,
        seq_length=MAX_SEQ_LENGTH,
        is_training=False,
        drop_remainder=False)

estimator.evaluate(input_fn=test_input_fn, steps=None)

相反,我只是在整个测试数据上使用了下面的函数

def getPrediction(in_sentences):
  labels = ["Negative", "Positive"]
  input_examples = [run_classifier.InputExample(guid="", text_a = x, text_b = None, label = 0) for x in in_sentences] # here, "" is just a dummy label
  input_features = run_classifier.convert_examples_to_features(input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
  predict_input_fn = run_classifier.input_fn_builder(features=input_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=False)
  predictions = estimator.predict(predict_input_fn)
  return [(sentence, prediction['probabilities'], labels[prediction['labels']]) for sentence, prediction in zip(in_sentences, predictions)]

3)我怎样才能得到预测的概率。有没有办法使用keras predict method?

update1

问题2更新- 你可以使用 20000 个训练样例进行测试吗getPrediction函数?......这对我来说需要更长的时间......甚至比在 20000 个示例上训练模型所花费的时间还要多。


1) From BERT 文档 https://aihub.cloud.google.com/p/products%2F2c1fe4d8-4ff3-4d4f-8ac4-45d445532a3b

输出字典包含:

pooled_output:整个序列的形状的池化输出 [批量大小,隐藏大小]。序列输出:每个的表示 输入序列中形状为 [batch_size, 最大序列长度,隐藏大小]。

我已经添加pooled_output对应于 CLS 向量的向量。

3) 您收到对数概率。只需申请softmax以获得正态概率。

现在剩下要做的就是模型报告它。我已经留下了日志问题,但它们不再是必要的了。

查看代码变化:

def create_model(is_predicting, input_ids, input_mask, segment_ids, labels,
                 num_labels):
  """Creates a classification model."""

  bert_module = hub.Module(
      BERT_MODEL_HUB,
      trainable=True)
  bert_inputs = dict(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids)
  bert_outputs = bert_module(
      inputs=bert_inputs,
      signature="tokens",
      as_dict=True)

  # Use "pooled_output" for classification tasks on an entire sentence.
  # Use "sequence_outputs" for token-level output.
  output_layer = bert_outputs["pooled_output"]

  pooled_output = output_layer

  hidden_size = output_layer.shape[-1].value

  # Create our own layer to tune for politeness data.
  output_weights = tf.get_variable(
      "output_weights", [num_labels, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "output_bias", [num_labels], initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):

    # Dropout helps prevent overfitting
    output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    probs = tf.nn.softmax(logits, axis=-1)

    # Convert labels into one-hot encoding
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

    predicted_labels = tf.squeeze(tf.argmax(log_probs, axis=-1, output_type=tf.int32))
    # If we're predicting, we want predicted labels and the probabiltiies.
    if is_predicting:
      return (predicted_labels, log_probs, probs, pooled_output)

    # If we're train/eval, compute loss between predicted and actual label
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, predicted_labels, log_probs, probs, pooled_output)

现在在model_fn_builder()添加对这些值的支持:

  # this should be changed in both places
  (predicted_labels, log_probs, probs, pooled_output) = create_model(
    is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

  # return dictionary of all the values you wanted
  predictions = {
      'log_probabilities': log_probs,
      'probabilities': probs,
      'labels': predicted_labels,
      'pooled_output': pooled_output
  }

Adjust getPrediction()因此,最终你的预测将如下所示:

('That movie was absolutely awful',
  array([0.99599314, 0.00400678], dtype=float32),  <= Probability
  array([-4.0148855e-03, -5.5197663e+00], dtype=float32), <= Log probability, same as previously
  'Negative', <= Label
  array([ 0.9181199 ,  0.7763732 ,  0.9999883 , -0.93533266, -0.9841384 ,
          0.78126144, -0.9918988 , -0.18764131,  0.9981035 ,  0.99999994,
          0.900716  , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
          0.9501321 ,  0.75836045,  0.49151263, -0.7886792 ,  0.97505844,
         -0.8931161 , -1.        ,  0.9318583 , -0.60531116, -0.8644371 ,
        ...
        and this is 768-d [CLS] vector (sentence embedding).    

关于2):最终我的训练花费了大约5分钟,测试花费了大约40秒。非常合理。

UPDATE

对于 20k 样本,训练时间为 12:48,测试时间为 2:07 分钟。

对于 10k 样本,时间分别为 8:40 和 1:07。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

BERT 微调后得到句子级嵌入 的相关文章

随机推荐

  • TypeError:jquery 1.9.1 版本中的“in”操作数 obj 无效

    ajax async false type POST url url module listing projectId data ajax true success function response each response funct
  • 使图像的一部分透明

    我想在按钮上放置图像 但我希望图像的一部分是透明的 我该怎么做呢 Try the Image OpacityMask http msdn microsoft com en us library ms743320 aspx财产 您可以给它一个
  • 将按钮的可见性绑定到两个文本框的内容的最简洁方法

    我有一个Button在我的应用程序中 我已将其 功能 绑定到是否TextBox是空的 如下所示
  • 提高始终加密证书的有效性

    我正在使用 SQL Server 的始终加密功能 使用受自签名证书保护的主密钥来加密数据库中的一些列 该证书是使用 SQL 2016 的 Management Studio 创建的 并且始终默认为比颁发日期提前一年的到期日期 它存储在当前用
  • 为什么 Clang 为引用和非空指针参数生成不同的代码?

    这与为什么 GCC 不能为两个 int32 的结构生成最佳运算符 q 66263263 我在 godbolt org 上研究了这个问题的代码 并注意到了这种奇怪的行为 struct Point int x y bool nonzero pt
  • Java FileHandler 禁用日志轮转

    我正在尝试禁用日志轮换 以供文件处理程序使用 FileHandler fh new FileHandler path run log 1000000 1 false 我想要的是一个日志 为每次运行创建 我不想轮换或备份旧文件 但使用此初始化
  • 从不同的数据框中获取数据

    我有一个数据框 Name Subset Type System A00 IU00 A OP A A00 IT00 PP A B01 IT 01A PP B B01 IU OP B B03 IM 09 B LP A B03 IM03A OP
  • 从 Gecko 和 Webkit 中的选择(范围)中检索父节点

    我试图在使用使用 createLink 命令的所见即所得编辑器时添加属性 我认为取回浏览执行该命令后创建的节点是很简单的 结果 我只能在 IE 中获取这个新创建的节点 有任何想法吗 以下代码演示了该问题 底部的调试日志在每个浏览器中显示不同
  • 将 AMQ 与 Rest API 网关集成

    我正在尝试将 AMQ 与 api 网关集成 以便我可以使用 API 网关中的 AWS 资源选项将消息直接从 api 网关推送到 AMQ 并在部署 AWS ARN 时收到此错误 因为集成包含无效操作 我应该在这里使用什么操作 以便 api 网
  • Eclipse Java:“创建字段”快速修复建议的模板?

    在构造函数中 我经常分配给一个不存在的字段 然后选择 Ctrl 1 在 CurrentType 类型中创建字段 memberField 问题是我希望该字段默认为最终字段 但事实并非如此 是否有用于此快速修复的模板 谢谢 我没有看到任何明显的
  • 是否可以从 Clojure 重新定义 Java 方法?

    使用多方法 我们可以向现有的 Java 类添加方法 我的问题是是否可以从 Clojure 代码重新定义一种特定方法以及如何重新定义 例如 如果您有以下课程 public class Shape public void draw 我希望能够运
  • 弹簧批次容错能力

    我正在尝试从 csv 文件导入城市数据 某些数据可能会重复 从而引发冲突错误ERROR duplicate key value violates unique constraint city unique idx Detail Key co
  • 两个 swift 函数极大地增加了编译时间

    返回并阅读我的应用程序的构建日志后 似乎存在一个奇怪的问题 两个 相对 简单的函数都将编译时间各增加一分钟 分别为 58 秒和 53 秒 这可以在我下面的构建日志中看到 这些函数位于我的 CAAgeViewController 中 并且都引
  • 真正换行 (LF) 的转义序列

    在 C 语言中 我们有几个常见的转义序列 r对于回车符 CR 这相当于做 015 n通常被描述为换行 LF 但我知道 n 将根据 CRLF 的要求被翻译成字符串 取决于操作系统 这相当于做 015 012 特别是如果我是东阿printf o
  • C++ - 使用引用类型的模板实例化

    我听说过一些关于引用到引用问题的知识this http www comeaucomputing com iso cwg defects html 106解决 我不太熟悉 C 委员会的术语 但我理解链接中的 Moved to DR 注释意味着
  • 检查输入值时出错

    我有一个使用 readline 要求人们输入数据的函数 但我不知道确保输入的数据符合我的标准的最佳方法 我认为 if 语句可能是检查错误的最佳方法 但我不确定如何合并它们 我使用它们的尝试显然是有缺陷的 见下文 举一个简单的例子 我最可能遇
  • vim 中的 Javascript 语法高亮显示

    还有其他人发现 VIM 的 Javascript 语法突出显示效果不佳吗 我发现有时我需要滚动才能调整语法突出显示 因为有时它会神秘地删除所有突出显示 有没有任何解决方法或方法来解决这个问题 我使用的是 vim 7 1 你可能想尝试这个改进
  • JBoss 作为客户端 5.1.0.GA 存储库丢失

    就在最近 我正在新计算机上创建新的 Maven 项目 这表明 jboss 作为客户端的依赖项不再可用
  • 有没有关于 Lucene.NET 的书籍 [关闭]

    Closed 此问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我在亚马逊上搜索过 但在 lucene net 上找不到书 你们在 lucene net 上找到过一本不
  • BERT 微调后得到句子级嵌入

    我遇到了这个page https colab research google com github google research bert blob master predicting movie reviews with bert on