The score is simply the product of the answer start token's and the answer end token's logits after a softmax has been applied. See the example below:
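In pseudo-Python (start_idx and end_idx stand for the predicted token positions; the names are only for illustration):

score = softmax(start_logits)[start_idx] * softmax(end_logits)[end_idx]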
Using the pipeline:
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad")
text = r"""
🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose
architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural
Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
TensorFlow 2.0 and PyTorch.
"""
question = "How many pretrained models are available in Transformers?"
question_answerer = pipeline("question-answering", model=model, tokenizer=tokenizer)
print(question_answerer(question=question, context=text))
Output:
{'score': 0.5254509449005127, 'start': 256, 'end': 264, 'answer': 'over 32+'}
Without the pipeline:
inputs = tokenizer(question, text, add_special_tokens=True, return_tensors="pt")
outputs = model(**inputs)
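# outputs.start_logits and outputs.end_logits each have shape
# (batch_size, sequence_length): one score per input token for being
# the start/end of the answer span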
First, we create a mask that is 1 for every context token and 0 everywhere else (question tokens and special tokens), using the BatchEncoding.sequence_ids method (https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding.sequence_ids):
# sequence_ids() returns None for special tokens, 0 for question tokens
# and 1 for context tokens, so the mask is True exactly for context tokens
non_answer_tokens = [x if x in [0, 1] else 0 for x in inputs.sequence_ids()]
non_answer_tokens = torch.tensor(non_answer_tokens, dtype=torch.bool)
non_answer_tokens
Output:
tensor([False, False, False, False, False, False, False, False, False, False,
False, False, False, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, False])
We use this mask to set the logits of the special tokens and the question tokens to negative infinity, then apply the softmax (negative infinity keeps those tokens from influencing the softmax result):
from torch.nn.functional import softmax
potential_start = torch.where(non_answer_tokens, outputs.start_logits, torch.tensor(float('-inf'), dtype=torch.float))
potential_end = torch.where(non_answer_tokens, outputs.end_logits, torch.tensor(float('-inf'), dtype=torch.float))
potential_start = softmax(potential_start, dim=1)
potential_end = softmax(potential_end, dim=1)
potential_start
Output:
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 1.0567e-04, 9.7031e-05, 1.9445e-06, 1.5849e-06, 1.2075e-07,
3.1704e-08, 4.7796e-06, 1.8712e-07, 6.2977e-08, 1.5481e-07, 8.0004e-08,
3.7896e-07, 1.6438e-07, 9.7762e-08, 1.0898e-05, 1.6518e-07, 5.6349e-08,
2.4848e-07, 2.1459e-07, 1.3785e-06, 1.0386e-07, 1.8803e-07, 8.1887e-08,
4.1088e-07, 1.5618e-07, 2.5624e-06, 1.8526e-06, 2.6710e-06, 6.8466e-08,
1.7953e-07, 3.6242e-07, 2.2788e-07, 2.3384e-06, 1.2147e-05, 1.6065e-07,
3.3257e-07, 2.6021e-07, 2.8140e-06, 1.3698e-07, 1.1066e-07, 2.8436e-06,
1.2171e-07, 9.9341e-07, 1.1684e-07, 6.8935e-08, 5.6335e-08, 1.3314e-07,
1.3038e-07, 7.9560e-07, 1.0671e-07, 9.1864e-08, 5.6394e-07, 3.0210e-08,
7.2176e-08, 5.4452e-08, 1.2873e-07, 9.2636e-08, 9.6012e-07, 7.8008e-08,
1.3124e-07, 1.3680e-06, 8.8716e-07, 8.6627e-07, 6.4750e-06, 2.5951e-07,
6.1648e-07, 8.7724e-07, 1.0796e-05, 2.6633e-07, 5.4644e-07, 1.7553e-07,
1.6015e-05, 5.0054e-07, 8.2263e-07, 2.6336e-06, 2.0743e-05, 4.0008e-07,
1.9330e-06, 2.0312e-04, 6.0256e-01, 3.9638e-01, 3.1568e-04, 2.2009e-05,
1.2485e-06, 2.4744e-06, 1.0092e-05, 3.1047e-06, 1.3597e-04, 1.5105e-06,
1.4960e-06, 8.1164e-08, 1.6534e-06, 4.6181e-07, 8.7354e-08, 2.2356e-07,
9.1145e-07, 8.8194e-06, 4.4202e-07, 1.9238e-07, 2.8077e-07, 1.4117e-05,
2.0613e-07, 1.2676e-06, 8.1317e-08, 2.2337e-06, 1.2399e-07, 6.1745e-08,
3.4725e-08, 2.7878e-07, 4.1457e-07, 0.0000e+00]],
grad_fn=<SoftmaxBackward>)
These probabilities can now be used to extract the answer's start and end tokens and to compute the answer score:
answer_start = torch.argmax(potential_start)  # most probable start token
answer_end = torch.argmax(potential_end)      # most probable end token
answer = tokenizer.decode(inputs.input_ids.squeeze()[answer_start:answer_end + 1])
print(potential_start.squeeze()[answer_start])
print(potential_end.squeeze()[answer_end])
print(potential_start.squeeze()[answer_start] * potential_end.squeeze()[answer_end])  # the score
print(answer)
Output:
tensor(0.6026, grad_fn=<SelectBackward>)
tensor(0.8720, grad_fn=<SelectBackward>)
tensor(0.5255, grad_fn=<MulBackward0>)
over 32 +
P.S.: Note that the manually computed score (0.5255) matches the pipeline's score from the first snippet. Keep in mind that this answer does not handle any special cases (such as an end token that comes before the start token).
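For completeness, here is a minimal sketch of one way to handle that case, in the spirit of the pipeline's decoding step (an illustration, not the pipeline's exact implementation): take the joint probability of every (start, end) pair and keep only the pairs where the start does not come after the end.

# scores[i, j] = P(start = i) * P(end = j) for every pair of positions
scores = potential_start.squeeze().unsqueeze(1) * potential_end.squeeze().unsqueeze(0)

# Zero out invalid pairs where the end precedes the start
# (keep only the upper triangle, i.e. start <= end)
scores = torch.triu(scores)

# Best remaining pair and its score
flat_index = torch.argmax(scores)
answer_start = flat_index // scores.shape[1]
answer_end = flat_index % scores.shape[1]
print(scores[answer_start, answer_end])
print(tokenizer.decode(inputs.input_ids.squeeze()[answer_start:answer_end + 1]))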