Huggingface Transformer 问题答案置信度得分

2023-12-31

我们如何从huggingface转换器问题答案的示例代码中获取答案置信度得分?我看到管道确实返回了分数,但是下面的核心也可以返回置信度分数吗?

from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering
import tensorflow as tf

tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = TFAutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")

text = r"""
???? Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose
architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural
Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
TensorFlow 2.0 and PyTorch.
"""

questions = [
    "How many pretrained models are available in Transformers?",
    "What does Transformers provide?",
    "Transformers provides interoperability between which frameworks?",
]

for question in questions:
    inputs = tokenizer.encode_plus(question, text, add_special_tokens=True, return_tensors="tf")
    input_ids = inputs["input_ids"].numpy()[0]

    text_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    answer_start_scores, answer_end_scores = model(inputs)

    answer_start = tf.argmax(
        answer_start_scores, axis=1
    ).numpy()[0]  # Get the most likely beginning of answer with the argmax of the score
    answer_end = (
        tf.argmax(answer_end_scores, axis=1) + 1
    ).numpy()[0]  # Get the most likely end of answer with the argmax of the score
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

    print(f"Question: {question}")
    print(f"Answer: {answer}\n")

代码摘自https://huggingface.co/transformers/usage.html https://huggingface.co/transformers/usage.html


该分数只是应用 softmax 函数后答案开始标记答案结束标记的 logits 的乘积。请看下面的例子: 管道输出:

import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad")

text = r"""
???? Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose
architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural
Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
TensorFlow 2.0 and PyTorch.
"""

question = "How many pretrained models are available in Transformers?"

question_answerer = pipeline("question-answering", model = model, tokenizer= tokenizer)

print(question_answerer(question=question, context = text))

Output:

{'score': 0.5254509449005127, 'start': 256, 'end': 264, 'answer': 'over 32+'}

不带管道:

inputs = tokenizer(question, text, add_special_tokens=True, return_tensors="pt")
outputs = model(**inputs)

首先,我们创建一个掩码,每个上下文标记为 1,否则为 0(问题标记和特殊标记)。我们使用batchencoding.sequence_ids https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding.sequence_ids方法:

non_answer_tokens = [x if x in [0,1] else 0 for x in inputs.sequence_ids()]
non_answer_tokens = torch.tensor(non_answer_tokens, dtype=torch.bool)
non_answer_tokens

Output:

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True, False])

我们使用此掩码将特殊标记和问题标记的 logits 设置为负无穷大,然后应用 softmax(负无穷大可防止这些标记影响 softmax 结果):

from torch.nn.functional import softmax

potential_start = torch.where(non_answer_tokens, outputs.start_logits, torch.tensor(float('-inf'),dtype=torch.float))
potential_end = torch.where(non_answer_tokens, outputs.end_logits, torch.tensor(float('-inf'),dtype=torch.float))

potential_start = softmax(potential_start, dim = 1)
potential_end = softmax(potential_end, dim = 1)
potential_start

Output:

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 1.0567e-04, 9.7031e-05, 1.9445e-06, 1.5849e-06, 1.2075e-07,
         3.1704e-08, 4.7796e-06, 1.8712e-07, 6.2977e-08, 1.5481e-07, 8.0004e-08,
         3.7896e-07, 1.6438e-07, 9.7762e-08, 1.0898e-05, 1.6518e-07, 5.6349e-08,
         2.4848e-07, 2.1459e-07, 1.3785e-06, 1.0386e-07, 1.8803e-07, 8.1887e-08,
         4.1088e-07, 1.5618e-07, 2.5624e-06, 1.8526e-06, 2.6710e-06, 6.8466e-08,
         1.7953e-07, 3.6242e-07, 2.2788e-07, 2.3384e-06, 1.2147e-05, 1.6065e-07,
         3.3257e-07, 2.6021e-07, 2.8140e-06, 1.3698e-07, 1.1066e-07, 2.8436e-06,
         1.2171e-07, 9.9341e-07, 1.1684e-07, 6.8935e-08, 5.6335e-08, 1.3314e-07,
         1.3038e-07, 7.9560e-07, 1.0671e-07, 9.1864e-08, 5.6394e-07, 3.0210e-08,
         7.2176e-08, 5.4452e-08, 1.2873e-07, 9.2636e-08, 9.6012e-07, 7.8008e-08,
         1.3124e-07, 1.3680e-06, 8.8716e-07, 8.6627e-07, 6.4750e-06, 2.5951e-07,
         6.1648e-07, 8.7724e-07, 1.0796e-05, 2.6633e-07, 5.4644e-07, 1.7553e-07,
         1.6015e-05, 5.0054e-07, 8.2263e-07, 2.6336e-06, 2.0743e-05, 4.0008e-07,
         1.9330e-06, 2.0312e-04, 6.0256e-01, 3.9638e-01, 3.1568e-04, 2.2009e-05,
         1.2485e-06, 2.4744e-06, 1.0092e-05, 3.1047e-06, 1.3597e-04, 1.5105e-06,
         1.4960e-06, 8.1164e-08, 1.6534e-06, 4.6181e-07, 8.7354e-08, 2.2356e-07,
         9.1145e-07, 8.8194e-06, 4.4202e-07, 1.9238e-07, 2.8077e-07, 1.4117e-05,
         2.0613e-07, 1.2676e-06, 8.1317e-08, 2.2337e-06, 1.2399e-07, 6.1745e-08,
         3.4725e-08, 2.7878e-07, 4.1457e-07, 0.0000e+00]],
       grad_fn=<SoftmaxBackward>)

现在可以使用这些概率来提取答案的开始和结束标记并计算答案分数:

answer_start = torch.argmax(potential_start)
answer_end = torch.argmax(potential_end)
answer = tokenizer.decode(inputs.input_ids.squeeze()[answer_start:answer_end+1])

print(potential_start.squeeze()[answer_start])
print(potential_end.squeeze()[answer_end])
print(potential_start.squeeze()[answer_start] *potential_end.squeeze()[answer_end])
print(answer)

Output:

tensor(0.6026, grad_fn=<SelectBackward>)
tensor(0.8720, grad_fn=<SelectBackward>)
tensor(0.5255, grad_fn=<MulBackward0>)
over 32 +

P.S.:请记住,此答案不涵盖任何特殊情况(结束令牌在开始令牌之前)。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Huggingface Transformer 问题答案置信度得分 的相关文章

随机推荐

  • spring 4.0.0 与 hibernate 4.30 的兼容性[重复]

    这个问题在这里已经有答案了 我在用春季4 0 0释放罐子 休眠科 4 3 0罐子在我的spring hibernate项目 我遇到一个错误org hibernate engine FilterDefinition没有找到 实际上 在旧的 h
  • axios 内部 for 循环

    我正在尝试在 for 循环内发送 axios 请求 但循环甚至在 axios 之前就已完成 以下是我的代码 let findEmail async gt for var i 0 i lt csvData length i axios pos
  • 如何在 Swift 中组合两个 Dictionary 实例?

    我如何附加一个Dictionary到另一个Dictionary使用斯威夫特 我正在使用AlamoFire将 JSON 内容发送到的库REST https en wikipedia org wiki Representational stat
  • 如何异步使用 Tornado 和 Redis?

    我正在尝试如何异步使用 Redis 和 Tornado 我找到了龙卷风 redis https github com leporo tornado redis但我需要的不仅仅是添加一个yield在代码中 我有以下代码 import redi
  • 如何用图像的像素创建图表?

    现在 我有一个图像 我想生成一个加权图 G V E 其中 V 是顶点集 E 是边集 图像中的每个像素作为图中的节点 但我不知道该怎么做 有人可以帮助我吗 最好是蟒蛇 非常感谢 问题补充 很抱歉我对问题的描述不够清楚 我的目标是使用图像的像素
  • 以编程方式设置 Jetty 配置以增加允许的 URL 长度

    我们使用嵌入式 Jetty 9 3 1 v20150714 并遇到了problem https stackoverflow com q 19549163 421049其中我们的长查询 URL 与其他标头相结合 比允许的长度要长 The so
  • 如何获取/设置 Firebase Cloud Functions v1 环境变量

    我以旧方式使用 Cloud Functions 包中的环境变量functions config 命令 但由于我更新到 v1 0 2 即使通过JSON parse process env FIREBASE CONFIG 就像文档告诉我的那样
  • 从机上的 Jenkins 工作空间路径不一致

    我们设置了一些共享工作空间的工作 各个分支的工作流程是 构建一个名为的大型 C 项目foo 执行多个下游测试 每个测试都使用foo 我们通过分配Use custom workspace构建工作区的下游作业领域 最近 我们选择了一个分支 并将
  • 在 iPhone X 上的 ARKit ARSession 期间从前置摄像头录制视频

    我正在使用一个ARSession结合一个ARFaceTrackingConfiguration来追踪我的脸 同时 我想从 iPhone X 的前置摄像头录制视频 为此 我使用AVCaptureSession但当我开始录音时ARSession
  • MongoDB 复合键

    我刚刚开始使用 MongoDb 我注意到我得到了很多重复的条目记录 而我本打算是唯一的 我想知道如何对我的数据使用复合键 并且我正在寻找有关如何创建它们的信息 最后 我使用 Java 来访问 mongo 和 morphia 作为我的 ORM
  • $(document).ready() 未针对 ajax 加载的内容执行

    在返回的部分页面中不执行 document ready 的原因可能是什么 两次工作正常 但第三次更新 html 后没有任何反应 alert html alert PopUpItem PopUpItem html html alert in
  • 在 Woocommerce 结帐中为特定选定的运输选项添加正文类别

    如果访问者在 Woocommerce 结账页面上处于特定的送货选项中 我会尝试向页面正文添加一个类 我已经完成了以下操作 但它没有添加课程 有人可以帮忙吗 add filter body class bbloomer wc product
  • 在使用 Transform 旋转时调整 UIView 的大小

    When my UIView使用变换属性旋转 CGAffineTransformMakeRotation 我需要拖动它的一个角 例如右下角 来调整它的大小 在此过程中 当用户拖动角时 视图的角必须跟随用户的手指 并通过增加 2 个边 右下角
  • 什么时候只需要 PartialEq 而不需要 Eq 比较合适?

    我在读铁锈书 https doc rust lang org book appendix 03 derivable traits html并尝试了解用例PartialEq and Eq特征 我意识到PartialEq适用于不一定是自反的关系
  • 动态代理和检查异常

    如何让我的动态代理抛出已检查的异常 我需要一个透明的接口包装器 它有时会抛出已检查的异常 例如IOException 没有第 3 方 AOP 或编写我自己的代理是否可以 手动修改接口的所有 20 个方法也不是一个选择 正如康拉德上面提到的
  • 进入设置屏幕

    我想从我的应用程序中打开 设置 gt 声音和显示 gt 电话铃声 屏幕 我怎样才能做到这一点 根据您的需要 有几种选择可以从您的应用程序中调出 铃声 设置屏幕 如果您想调出通常可通过系统设置访问的实际首选项屏幕 让用户通过应用程序修改手机的
  • 文件上传“multipart/form”异常 org.apache.commons.fileupload.FileUploadBase$InvalidContentTypeException

    我尝试使用 Apache Commons 进行文件上传 但抛出以下异常 org apache commons fileupload FileUploadBase InvalidContentTypeException 请求不包含 multi
  • Dropwizard 中的自定义 Jetty 过滤器

    我正在尝试在 Dropwizard 实例中添加自定义标头过滤器 以检查请求的版本是否与 Dropwizard 实例的版本同步 我看到你可以使用FilterBuilder添加码头CrossOriginFilters 但是 我无法弄清楚如何设置
  • 片段 popbackstack 行为在 25.1.0 和 25.1.1 中被破坏

    自从支持版本 25 1 0 和最新的 25 1 1 以来 我在片段替换 添加方面遇到了奇怪的行为 25 1 0 已报告问题Android fragmentTransaction replace 不适用于支持库 25 1 0 https st
  • Huggingface Transformer 问题答案置信度得分

    我们如何从huggingface转换器问题答案的示例代码中获取答案置信度得分 我看到管道确实返回了分数 但是下面的核心也可以返回置信度分数吗 from transformers import AutoTokenizer TFAutoMode