scikit-learn 中的 class_weight 参数如何工作?

2023-11-29

我很难理解如何class_weightscikit-learn 的逻辑回归中的参数运行。

情况

我想使用逻辑回归对非常不平衡的数据集进行二元分类。这些类别被标记为 0(阴性)和 1(阳性),观察到的数据比例约为 19:1,大多数样本具有阴性结果。

第一次尝试:手动准备训练数据

我将拥有的数据分成不相交的数据集进行训练和测试(大约 80/20)。然后我手工对训练数据进行随机采样,得到比19:1不同比例的训练数据;从 2:1 -> 16:1。

然后,我在这些不同的训练数据子集上训练逻辑回归,并绘制召回率 (= TP/(TP+FN)) 作为不同训练比例的函数。当然,召回率是根据观察到的比例为 19:1 的不相交 TEST 样本计算的。请注意,虽然我在不同的训练数据上训练了不同的模型,但我在相同(不相交)的测试数据上计算了所有模型的召回率。

结果正如预期的那样:在 2:1 的训练比例下,召回率约为 60%,当达到 16:1 时,召回率下降得相当快。有几个比例为 2:1 -> 6:1,召回率远高于 5%。

第二次尝试:网格搜索

接下来,我想测试不同的正则化参数,因此我使用 GridSearchCV 并制作了一个由多个值组成的网格C参数以及class_weight范围。将我的 n:m 比例的负:正训练样本翻译成字典语言class_weight我以为我只是指定几个字典如下:

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

我还包括None and auto.

这一次的结果完全出乎意料。对于每个值,我的所有回忆都很小(class_weight except auto。所以我只能假设我对如何设置的理解class_weight字典错了。有趣的是,class_weight对于所有值,网格搜索中“auto”的值约为 59%C,我猜它平衡为 1:1?

我的问题

  1. 你如何正确使用class_weight在训练数据中实现与实际提供的数据不同的平衡?具体来说,我传递给什么字典class_weight使用 n:m 比例的负:正训练样本?

  2. 如果你通过了各种class_weight字典到 GridSearchCV,在交叉验证期间,它会根据字典重新平衡训练折叠数据,但使用真实的给定样本比例来计算测试折叠上的评分函数吗?这一点至关重要,因为任何指标只有来自观察到的比例的数据才对我有用。

  3. 什么是auto的价值class_weight尽量按比例做?我阅读了文档,我认为“平衡数据与其频率成反比”只是意味着它使其达到 1:1。它是否正确?如果没有,有人可以澄清吗?


首先,仅仅依靠回忆可能并不好。通过将所有内容分类为正类,您可以简单地实现 100% 的召回率。 我通常建议使用 AUC 来选择参数,然后找到您感兴趣的操作点(例如给定的精度水平)的阈值。

For how class_weight有效:它会惩罚样本中的错误class[i] with class_weight[i]而不是 1。所以较高的班级权重意味着您想要更加重视某个班级。从你的说法来看,0 类的出现频率似乎是 1 类的 19 倍。所以你应该增加class_weight类 1 相对于类 0,例如 {0:.1, 1:.9}。 如果class_weight总和不等于 1,它基本上会改变正则化参数。

For how class_weight="auto"有效,你可以看看这次讨论。 在开发版本中您可以使用class_weight="balanced",这更容易理解:它基本上意味着复制较小的类,直到拥有与较大类中的样本一样多的样本,但以隐式方式进行。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

scikit-learn 中的 class_weight 参数如何工作? 的相关文章

  • 没有名为 crypto.cipher 的模块

    我现在正在尝试加密一段时间 我最近得到了这个基于 python 的密码器 名为PythonCrypter https github com jbertman PythonCrypter 我对 Python 相当陌生 当我尝试通过终端打开 C
  • Django 管理员在模型编辑时间歇性返回 404

    我们使用 Django Admin 来维护导出到我们的一些站点的一些数据 有时 当单击标准更改列表视图来获取模型编辑表单而不是路由到正确的页面时 我们会得到 Django 404 页面 模板 它是偶尔发生的 我们可以通过重新加载三次来重现它
  • SQLAlchemy 通过关联对象声明式多对多自连接

    我有一个用户表和一个朋友表 它将用户映射到其他用户 因为每个用户可以有很多朋友 这个关系显然是对称的 如果用户A是用户B的朋友 那么用户B也是用户A的朋友 我只存储这个关系一次 除了两个用户 ID 之外 Friends 表还有其他字段 因此
  • 使 django 服务器可以在 LAN 中访问

    我已经安装了Django服务器 可以如下访问 http localhost 8000 get sms http 127 0 0 1 8000 get sms 假设我的IP是x x x x 当我这样做时 从同一网络下的另一台电脑 my ip
  • Flask 会话变量

    我正在用 Flask 编写一个小型网络应用程序 当两个用户 在同一网络下 尝试使用应用程序时 我遇到会话变量问题 这是代码 import os from flask import Flask request render template
  • 从字符串中删除识别的日期

    作为输入 我有几个包含不同格式日期的字符串 例如 彼得在16 45 我的生日是1990年7月8日 On 7 月 11 日星期六我会回家 I use dateutil parser parse识别字符串中的日期 在下一步中 我想从字符串中删除
  • 测试 python Counter 是否包含在另一个 Counter 中

    如何测试是否是pythonCounter https docs python org 2 library collections html collections Counter is 包含在另一个中使用以下定义 柜台a包含在计数器中b当且
  • 基于代理的模拟:性能问题:Python vs NetLogo & Repast

    我正在 Python 3 中复制一小段 Sugarscape 代理模拟模型 我发现我的代码的性能比 NetLogo 慢约 3 倍 这可能是我的代码的问题 还是Python的固有限制 显然 这只是代码的一个片段 但 Python 却花费了三分
  • Python 函数可以从作用域之外赋予新属性吗?

    我不知道你可以这样做 def tom print tom s locals locals def dick z print z name z name z guest Harry print z guest z guest print di
  • 如何在ipywidget按钮中显示全文?

    我正在创建一个ipywidget带有一些文本的按钮 但按钮中未显示全文 我使用的代码如下 import ipywidgets as widgets from IPython display import display button wid
  • Flask如何获取请求的HTTP_ORIGIN

    我想用我自己设置的 Access Control Allow Origin 标头做出响应 而弄清楚请求中的 HTTP ORIGIN 参数在哪里似乎很混乱 我在用着烧瓶 0 10 1 以及HTTP ORIGIN似乎是这个的特点之一object
  • 无法在 Python 3 中导入 cProfile

    我试图将 cProfile 模块导入 Python 3 3 0 但出现以下错误 Traceback most recent call last File
  • 将图像分割成多个网格

    我使用下面的代码将图像分割成网格的 20 个相等的部分 import cv2 im cv2 imread apple jpg im cv2 resize im 1000 500 imgwidth im shape 0 imgheight i
  • 如何在seaborn displot中使用hist_kws

    我想在同一图中用不同的颜色绘制直方图和 kde 线 我想为直方图设置绿色 为 kde 线设置蓝色 我设法弄清楚使用 line kws 来更改 kde 线条颜色 但 hist kws 不适用于显示 我尝试过使用 histplot 但我无法为
  • 每个 X 具有多个 Y 值的 Python 散点图

    我正在尝试使用 Python 创建一个散点图 其中包含两个 X 类别 cat1 cat2 每个类别都有多个 Y 值 如果每个 X 值的 Y 值的数量相同 我可以使用以下代码使其工作 import numpy as np import mat
  • 如何在 Python 中追加到 JSON 文件?

    我有一个 JSON 文件 其中包含 67790 1 kwh 319 4 现在我创建一个字典a dict我需要将其附加到 JSON 文件中 我尝试了这段代码 with open DATA FILENAME a as f json obj js
  • 类型错误:预期单个张量时的张量列表 - 将 const 与 tf.random_normal 一起使用时

    我有以下 TensorFlow 代码 tf constant tf random normal time step batch size 1 1 我正进入 状态TypeError List of Tensors when single Te
  • Conda SafetyError:文件大小不正确

    使用创建 Conda 环境时conda create n env name python 3 6 我收到以下警告 Preparing transaction done Verifying transaction SafetyError Th
  • 如何计算 pandas 数据帧上的连续有序值

    我试图从给定的数据帧中获取连续 0 值的最大计数 其中包含来自 pandas 数据帧的 id date value 列 如下所示 id date value 354 2019 03 01 0 354 2019 03 02 0 354 201
  • 发送用户注册密码,django-allauth

    我在 django 应用程序上使用 django alluth 进行身份验证 注册 我需要创建一个自定义注册表单 其中只有一个字段 电子邮件 密码将在服务器上生成 这是我创建的表格 from django import forms from

随机推荐