I am fine-tuning Hugging Face's facebook/bart-large-mnli model.
To suit my needs, I use the following training arguments:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=model_directory,     # output directory
    num_train_epochs=30,            # total number of training epochs
    per_device_train_batch_size=1,  # batch size per device during training (was 16) - don't go over 1, it runs out of memory
    per_device_eval_batch_size=2,   # batch size for evaluation (was 64) - don't go over 2, it runs out of memory
    warmup_steps=500,               # number of warmup steps for learning rate scheduler
    weight_decay=0.01,              # strength of weight decay
)
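Since the comments above note that batch size 1 is a hard memory ceiling, one common workaround is gradient accumulation: the effective batch size becomes per_device_train_batch_size × gradient_accumulation_steps at no extra memory cost. A hedged sketch of the same TrainingArguments call (the added flags are my suggestion, not part of the original setup; fp16 assumes a CUDA GPU):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=model_directory,      # unchanged from above
    num_train_epochs=30,
    per_device_train_batch_size=1,   # still 1 sample per step...
    gradient_accumulation_steps=16,  # ...but gradients are accumulated over 16
                                     # steps, giving an effective batch size of 16
    per_device_eval_batch_size=2,
    warmup_steps=500,
    weight_decay=0.01,
    fp16=True,                       # half-precision training; assumes a CUDA GPU
)
```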
from transformers import BartForSequenceClassification

model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
trainer = Trainer(
    model=model,                      # the instantiated 🤗 Transformers model to be trained
    args=training_args,               # training arguments, defined above
    compute_metrics=compute_metrics,  # a function to compute the metrics
    train_dataset=train_dataset,      # training dataset
    eval_dataset=test_dataset,        # evaluation dataset
)
# Train the model
trainer.train()
The compute_metrics function I use is:
import numpy as np
from datasets import Dataset, load_metric
from transformers import EvalPrediction

def compute_metrics(p: EvalPrediction):
    metric_acc = load_metric("accuracy")
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(preds, axis=1)
    result = {}
    result["accuracy"] = metric_acc.compute(predictions=preds, references=p.label_ids)["accuracy"]
    return result
But no matter how much training or test data I use, or how many epochs, I get an accuracy of 0.5 when I call trainer.evaluate().
My questions are:
- How can I improve it?
- How can I implement other evaluation metrics, e.g. the F1 score?
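On the first question: an accuracy stuck at exactly 0.5 regardless of data size or epochs often means the model has collapsed to predicting a single class over a roughly 50/50 label split, or that the label ids don't line up with the 3-way MNLI head (0 = contradiction, 1 = neutral, 2 = entailment for this checkpoint). A quick sanity check is to compare the prediction and label histograms. A minimal sketch with synthetic logits standing in for p.predictions (all names and values here are mine, for illustration only):

```python
import numpy as np

# Synthetic stand-ins for p.predictions / p.label_ids from EvalPrediction.
logits = np.array([[2.0, 0.1, -1.0],
                   [1.5, 0.2, -0.5],
                   [1.8, -0.3, 0.0],
                   [2.2, 0.0, -0.2]])
label_ids = np.array([0, 2, 0, 2])

preds = np.argmax(logits, axis=1)
print("prediction counts:", np.bincount(preds, minlength=3))      # [4 0 0]
print("label counts:     ", np.bincount(label_ids, minlength=3))  # [2 0 2]

# A degenerate prediction histogram (one class gets everything) over a
# 50/50 label split yields exactly 0.5 accuracy.
accuracy = float(np.mean(preds == label_ids))
print("accuracy:", accuracy)  # 0.5
```

If the real prediction histogram looks like this, the problem is in training (learning rate, labels, data), not in the metric code.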
I tried changing (adding) the metrics:
def compute_metrics(p: EvalPrediction):
    load_accuracy = load_metric("accuracy")
    load_f1 = load_metric("f1")
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(preds, axis=1)
    result = {}
    result["accuracy"] = load_accuracy.compute(predictions=preds, references=p.label_ids)["accuracy"]
    result["f1"] = load_f1.compute(predictions=preds, references=p.label_ids)["f1"]
    return result
But then I got this error when running trainer.evaluate():
ValueError: pos_label=1 is not a valid label. It should be one of [0, 2]
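For what it's worth on this error: the f1 metric defaults to binary averaging with pos_label=1, and the evaluation labels here apparently contain only 0 and 2, so a binary F1 against class 1 is undefined. Passing average="macro" (or "weighted") to load_f1.compute should side-step this for a multi-class setup; this is my reading of the cause, not a confirmed fix. To make the averaging concrete, a self-contained numpy sketch of macro F1 (the macro_f1 helper is mine, not a library function):

```python
import numpy as np

def macro_f1(preds, refs):
    """Per-class F1, averaged with equal weight per class (macro averaging)."""
    labels = np.unique(np.concatenate([preds, refs]))
    scores = []
    for c in labels:
        tp = np.sum((preds == c) & (refs == c))
        fp = np.sum((preds == c) & (refs != c))
        fn = np.sum((preds != c) & (refs == c))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        scores.append(2 * precision * recall / (precision + recall)
                      if precision + recall else 0.0)
    return float(np.mean(scores))

# Labels here are only 0 and 2, exactly the situation in the error message:
# a binary F1 with pos_label=1 is undefined, but macro F1 is well-defined.
preds = np.array([0, 2, 2, 0, 2])
refs = np.array([0, 2, 0, 0, 2])
print(macro_f1(preds, refs))  # 0.8
```

With the metric object, the equivalent call would be load_f1.compute(predictions=preds, references=p.label_ids, average="macro").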
You can refer to my previous question here for more details about my fine-tuning.
Update:
This is the tokenizer I use:
from transformers import BartTokenizerFast
tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-large-mnli')
As stated in my other related question, this is what I used to create and convert my datasets.
As I wrote above, you can refer to my linked questions for more details on my whole process; I don't think it's necessary to repeat everything in every question, but please correct me if I'm wrong.