HuggingFace 评估微调的零样本模型

2023-12-14

我正在微调 HuggingFacefacebook/bart-large-mnli为了满足我的需要,我使用以下参数:

training_args = TrainingArguments(
    output_dir=model_directory,      # output directory
    num_train_epochs=30,              # total number of training epochs
    per_device_train_batch_size=1,  # batch size per device during training - 16 - Don't go over 1, it's out of memory
    per_device_eval_batch_size=2,   # batch size for evaluation - 64 - Don't go over 2, it's out of memory
    warmup_steps=500,                 # number of warmup steps for learning rate scheduler - 500
    weight_decay=0.01,               # strength of weight decay
)

model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli")

trainer = Trainer(
    model=model,                          # the instantiated ???? Transformers model to be trained
    args=training_args,                   # training arguments, defined above
    compute_metrics=compute_metrics,      # a function to compute the metrics
    train_dataset=train_dataset,          # training dataset
    eval_dataset=test_dataset             # evaluation dataset
)

# Train the trainer
trainer.train()

The compute_metrics我用的是:

import numpy as np
from datasets import Dataset, load_metric
from transformers import EvalPrediction

def compute_metrics(p: EvalPrediction):
  metric_acc = load_metric("accuracy")
  preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
  preds = np.argmax(preds, axis=1)
  result = {}
  result["accuracy"] = metric_acc.compute(predictions=preds, references=p.label_ids)["accuracy"]
  return result

但无论我使用多少训练或测试数据,或者多少个纪元,当我使用trainer.evaluate()我得到的准确度为 0.5。

我的问题是:

  1. 我该如何改进它?
  2. 如何实施其他评估指标?例如 F1 分数。

我尝试更改(添加)指标:

def compute_metrics(p: EvalPrediction):
  load_accuracy = load_metric("accuracy")
  load_f1 = load_metric("f1")
  preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
  preds = np.argmax(preds, axis=1)
  result = {}
  result["accuracy"] = load_accuracy.compute(predictions=preds, references=p.label_ids)["accuracy"]
  result["f1"] = load_f1.compute(predictions=preds, references=p.label_ids)["f1"]
  return result

但后来我在运行时遇到了这个错误trainer.evaluate():

ValueError: pos_label=1 不是有效标签。它应该是 [0, 2] 之一


您可以参考我之前的问题以了解有关我的微调的更多详细信息here


Update:

这是我使用的标记器:

from transformers import BartTokenizerFast
tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-large-mnli')

正如我在其他相关问题中所述,this是我用来创建和转换数据集的

正如我上面所写的,您可以参考我的链接问题以获取有关我所有流程的更多数据,我觉得没有必要再次将所有内容放入每个问题中,如果我错了,请纠正我。


0.5 并不是一个令人满意的准确度分数。

回答你的第一个问题。如何改进呢?

正如您所提到的,您已经尝试增加纪元数和批量大小。您可以尝试使用不同的优化器而不是 AdamW 进行训练,并使用不同的权重衰减。

尝试使用 SGD 或 Adagrad。

  • 与 AdamW 相比,它们需要调整的超参数更少。这可以使它们更容易调整,并且对于不同的数据集和架构更加稳健。

  • 通过根据历史梯度调整每个参数的学习率,可以帮助模型收敛到更好的解决方案。这在损失情况复杂且模型需要遍历许多局部最小值以找到全局最小值的任务中特别有用。

      # for using SGD
      from transformers import AdamW, get_linear_schedule_with_warmup, SGD
    
      training_args = TrainingArguments(
          output_dir=model_directory,      
          num_train_epochs=30,              
          per_device_train_batch_size=1,  
          per_device_eval_batch_size=2,   
          warmup_steps=500,                 
          weight_decay=0.01,              
          learning_rate=0.01,              
          optimizer_type=SGD,             
          optimizer_params={"momentum": 0.9}  # specify the optimizer hyperparameters
      )
    

这里的“动量”用于通过将先前更新的一小部分添加到模型参数的当前更新来加速优化算法的收敛。减少值以找到最合适的值。

 # For using Adagrad
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=model_directory,      
    num_train_epochs=30,              
    per_device_train_batch_size=1,  
    per_device_eval_batch_size=2,   
    warmup_steps=500,                 
    weight_decay=0.01,               
    learning_rate=0.01,              
    optimizer_type="Adagrad",        
    optimizer_params={"initial_accumulator_value": 0.1}  # specify the optimizer hyperparameters
)

这里'initial_accumulator_value'是每个参数的历史梯度累加器的初始值。这是每个参数的梯度平方的运行和,用于在训练期间调整每个参数的学习率。尝试改变值。

回答第二个问题。

尝试将平均参数添加为“宏”以计算多类分类的 F1 分数。我相信您的数据集标签编码时存在一些错误。如果您正在进行二元分类,则 0 或 2 是 pos_label 的有效值,具体取决于您想要将其视为正标签。

# for Multi-class classification
import numpy as np
from datasets import Dataset, load_metric
from transformers import EvalPrediction

def compute_metrics(p: EvalPrediction):
  metric_acc = load_metric("accuracy")
  metric_f1 = load_metric("f1")
  preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
  preds = np.argmax(preds, axis=1)
  result = {}
  result["accuracy"] = metric_acc.compute(predictions=preds, references=p.label_ids)["accuracy"]
  result["f1"] = metric_f1.compute(predictions=preds, references=p.label_ids, average='macro')["f1"]
  return result
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

HuggingFace 评估微调的零样本模型 的相关文章

  • 调整添加的绘制组件的大小和奇怪的摆动行为

    这个问题困扰了我好几天 我正在制作一个特殊的绘画程序 我制作了一个 JPanel 并添加了使用 Paint 方法绘制的自定义 jComponent 问题是 每当我调整窗口大小时 所有添加的组件都会 消失 或者只是不绘制 因此我最终会得到一个
  • 在 RESTful Web 服务中实现注销

    我正在开发一个需要注销服务的移动应用程序 登录服务是通过数据库验证来完成的 现在我陷入了注销状态 退一步 您没有提供有关如何在应用程序中执行身份验证的详细信息 并且很难猜测您在做什么 但是 需要注意的是 在 REST 应用程序中 不能有会话
  • Antlr 解析器运算符优先级

    考虑以下语法 我对运算符优先级有疑问 例如 res 2 a b有一个类似的解析树res 2 a b 我知道问题出在哪里 但我没有想到没有相互左递归的 漂亮 解决方案 你能帮我一点忙吗 该语法与自定义访问者一起使用 grammar Math
  • 仅当显式选择行时才关闭 ui-bootstrap typeahead

    我创建了这个jsBin http jsbin com livuqafe 2 edit来证明我遇到的问题 如果您转到此处 请尝试输入 五 并继续 你的自然反应是输入 五 然后按 Tab 如果你想要 五百 你可以向下箭头一次 但是 在这种情况下
  • 如何通过索引访问 JSON 对象中的字段

    我知道这不是最好的方法 但我别无选择 我必须通过索引访问 JSONObject 中的项目 访问对象的标准方法是只写this objectName or this objectName 我还找到了一种获取 json 对象内所有字段的方法 fo
  • 测量窗口偏移

    有没有一种方法可以测量 jQuery 中窗口的偏移量 以便我可以比较 固定 元素和相对定位元素的位置 我需要能够知道窗口滚动了多远 以便我可以使用该图来计算固定元素的高度 相对于视口顶部 和相对对象的高度 相对于顶部 之间的差异文件的内容
  • MySQL 查询计算上个月

    我想计算上个月的订单总额 我收到了从当前日期获取当月数据的查询 SELECT SUM goods total AS Total Amount FROM orders WHERE order placed date gt date sub c
  • PrimeFaces 对话框参考父级

    我有一个 xhtml 页面 显示带有条目的数据表 我还有一个用于插入新条目的按钮 该按钮显示一个包含表单的对话框 插入表格用作
  • Pandas 与 Numpy 数据帧

    看这几行代码 df2 df copy df2 1 df 1 df 1 values 1 df2 ix 0 0 我们的教练说我们需要使用 values属性来访问底层的 numpy 数组 否则我们的代码将无法工作 我知道 pandas Data
  • Mono 应用程序在非阻塞套接字发送时冻结

    我在 debian 9 上的 mono 下运行一个服务器应用程序 大约有 1000 2000 个客户端连接 并且应用程序经常冻结 CPU 使用率达到 100 我执行 kill QUIT pid 来获取线程堆栈转储 但它总是卡在这个位置
  • php 数组中出现意外的 json 输出结构

    我正在尝试转换动态数据 如何从 PHP 获取此 JSON JSON 122240cb 253c 4046 adcd ae81266709a6 item 0 3 这就是我所做的 但它不起作用 PHP json array 122240cb 2
  • 将第三个表链接到多对多关联中的桥接表

    设计这个数据库的正确方法是什么 这是我设置表格的方式 我在名为 教师 的表和名为 仪器 的表之间存在多对多关系 然后我有一个连接两者的桥接表 我想将另一个表与 BRIDGE 表关联起来 意思是乐器 老师的组合 该表有 3 行 指定老师可以教
  • Amazon RDS for SQL Server 是否支持 SSIS?

    从谷歌搜索中读到一些相互矛盾的答案 不确定答案是是 否还是可能 我觉得读的时候已经很清楚了this http docs aws amazon com AmazonRDS latest UserGuide CHAP SQLServer htm
  • 如何确定 CultureInfo 实例是否支持拉丁字符

    是否可以确定是否CultureInfo http msdn microsoft com en us library system globalization cultureinfo aspx我正在使用的实例是否基于拉丁字符集 我相信你可以使
  • 如何在 JFreeChart 中设置多个系列的线条粗细?

    我创建了很多图表 在他们每个人中我都需要打电话 renderer setSeriesStroke i new BasicStroke 2 0f 对于每个系列 renderer is chart getXYPlot getRenderer 我
  • 如何在 OSX 上安装 LaTeX .sty 文件?

    我设置了一个 LaTeX 项目 tex documents some file tex support todonotes sty where some file tex uses todonotes usepackage colorinl
  • 使用 WGL 创建现代 OpenGL 上下文?

    我正在尝试使用 Windows 函数创建 OpenGL 上下文 现代版本 基本上代码就是 创建窗口类 注册班级 创建一个窗口 choose PIXELFORMATDESCRIPTOR并设置它 创建旧版 OpenGL 上下文 使上下文成为当前
  • Android 材料芯片组件崩溃应用程序。无法膨胀 xml

    Tried Chip来自两个支持库的组件 com google android support design 28 0 0 rc01和材料 com google android material material 1 0 0 rc01 堆栈
  • 禁用允许文本选择的

    残疾人可以吗
  • PyAudio ErrNo 输入溢出 -9981

    我遇到了与用户相同的错误 Python 使用 Pyaudio 以 16000Hz 录制音频时出错 https stackoverflow com questions 12994981 python error audio recording

随机推荐

  • Excel VBA 创建 json 有效负载

    我正在使用 Excel VBA 并调用外部 REST API 调用需要 json 格式的有效负载 我在创建 json 格式时遇到问题 customerContext identifiers apiName email value email
  • Docker 绑定安装 - 在浏览器上看不到更改

    我在 Windows Home 上使用 docker toolbox 我能够运行 jekyll serve Web 服务器映像来查看浏览器上的默认页面 但是当我尝试在 VS Code 上编辑文件时 刷新浏览器后看不到更改 知道为什么刷新后看
  • 设置 WPF 用户控件图标时无法识别 URI 前缀错误

    我正在创建一个 WPF 窗口并在其中加载用户控件 如下所示 Uri uri new Uri Views ApplicationInfo xaml UriKind RelativeOrAbsolute UserControl versionI
  • 通过搜索嵌套对象属性来过滤对象数组

    我有一个对象数组 我想通过将嵌套属性与搜索词进行比较来过滤它们 例如 var array category Business users name Sally tags tag accounting tag marketing name B
  • R 闪亮仪表板中标题中的主页按钮

    我试图在我的 Shiny 应用程序的标题中添加一个主页按钮 以便每当有人从任何选项卡单击它时 它都会重定向到第一页 目前 我在每个选项卡中使用一个actionButton 和observeEvent 返回第一页 我无法在 Shiny 应用程
  • 对同一行的并发更新

    我试图弄清楚如果我同时从不同的客户端发出以下两个查询 MySQL InnoDB 中应该发生什么 UPDATE tbl SET a a 1 WHERE id 123 UPDATE tbl SET b b 1 WHERE id 123 如果查询
  • Directx 的变化

    我的 win8 和 directx 库有问题 我有 directx jun 2010 我添加了它的 d3dx11 lib 和 h 文件 但它不起作用并说找不到库 我发现下面的 hte 链接说您可以使用 win8 sdk 而不是 direct
  • Windows 上的 Qt 5.1.0 使用 minGW 4.8 需要很长时间来调试

    我已从 qt project 下载页面下载并安装了适用于 Windows 32 位 MinGW 4 8 的 Qt 5 1 0 我已经运行了安装程序 并且能够使用这些库和 minGW 4 8 32 位工具链来编译和运行应用程序 但是 我有一个
  • 为什么文件被放置在“C:\Users\<用户名>AppData\Local\VirtualStore\Program Files(x86)”中?

    我最近更新了我的视觉基本6 0应用程序 现在包含一个 exe manifest 文件以防止UAC虚拟化 应用此更新后 一些用户找不到他们的数据文件 AccessMDB 文件 经过系统搜索后 他们最终在C Users
  • 使用指针 C++ 实现双向链表

    我目前正在自学 C 并尝试使用部分完成的指针在 C 中实现双向链表 我知道代码当前无法处理悬空节点或输出错误 接下来我将实现这两 者 但是 代码至少应该能够构造一个列表对象并向其中添加元素 目前 当我尝试调用列表的构造函数时 出现错误 该错
  • 更新slot vuejs中的数据

    你好 我在 laravel 项目中使用 vuejs 这是我的 vuejs 代码 Vue component search and select template div div
  • 即使导出后,process.env 变量也未定义

    我正在编写一个 Node js Express 应用程序 并希望使用环境变量来设置服务器应运行的端口 但是 我似乎无法得到process env PORT阅读我的PORT环境变量 我已经使用定义了 PORT 环境变量export像这样 ex
  • 如何使用 Maven 插件从带有注释的现有实体生成 DDL?

    我有 Maven 项目 我想从现有实体生成 DDL 我怎样才能做到这一点 有没有可以生成 DDL 的 Maven 插件 我正在使用JPA 打开jpa openjpa maven plugin 插件提供了一个目标sql 使用此目标 可以从现有
  • 禁用 Windows 窗体上的所有事件

    有没有办法暂时禁用 Windows 窗体上的所有事件 我遇到的情况是 辅助线程上的处理被主线程上的事件破坏 主线程事件正在修改数据绑定到辅助线程使用的变量的控件的内容 寻找一种方法来 锁定 表单 直到辅助线程上的处理完成 显然 将处理移至主
  • 非规格化向量

    如何对已标准化的向量进行反标准化以获得标准化之前的原始值 例如 vec 0 5 1 0 0 0 vec length sqrt vec x 2 vec y 2 vec z 2 vec normalized vec x vec length
  • Visual Studio 设计时属性 - 表单列表下拉菜单

    编辑 需要明确的是 我知道如何通过反射获取表单列表 我更关心设计时属性网格 我有一个具有 Form 类型公共属性的用户控件 我希望能够在设计时从下拉列表中选择一个表单 我想从一组命名空间填充表单下拉列表 UI Foo Forms 如果您拥有
  • 如何从 Web 扩展弹出 JavaScript 中知道浏览器是 Chrome 还是 Firefox?

    我正在使用chromeChrome 和 Firefox 的命名空间 但想知道哪个浏览器正在运行网络扩展 扩展资源的链接在 Chrome 和 Firefox 中具有不同的方案 const isFirefox chrome runtime ge
  • 使用 float 格式说明符打印 int 变量

    int main int a 5 float b 7 5 printf d f n a b printf d f n a a return 0 当我在 gcc 编译器中编译它时 输出是 5 7 500000 5 7 500000 但是在 V
  • SQL Server中for循环的语法

    a 的语法是什么forSQL 中的循环 没有 for 循环 只有 while 循环 DECLARE i int 0 WHILE i lt 20 BEGIN SET i i 1 do some work END
  • HuggingFace 评估微调的零样本模型

    我正在微调 HuggingFacefacebook bart large mnli为了满足我的需要 我使用以下参数 training args TrainingArguments output dir model directory out