什么时候在keras中使用sample_weights合适?

2023-12-24

根据这个question https://stackoverflow.com/questions/43459317/keras-class-weight-vs-sample-weights-in-the-fit-generator,我了解到class_weight in keras在训练期间应用加权损失,并且sample_weight如果我对所有训练样本没有同等的信心,就会按样本做一些事情。

所以我的问题是,

  1. 验证期间的损失是否由class_weight,或者仅在训练期间加权?
  2. 我的数据集有 2 个类别,实际上类别分布并没有严重不平衡。该比例约为。 1.7 : 1.是否有必要使用class_weight平衡损失甚至使用过采样?可以将稍微不平衡的数据保留为通常的数据集处理吗?
  3. 我可以简单地考虑一下吗sample_weight作为我赋予每个训练样本的权重?而且我的训练样本可以得到同等的置信度,所以我可能不需要使用这个。

  1. 从 keras 文档中可以看出

类别权重:可选字典将类索引(整数)映射到权重(浮点)值,用于对损失函数进行加权(仅在训练期间)。这对于告诉模型“更多地关注”代表性不足的类别的样本很有用。

So class_weight只影响训练期间的损失。我本人一直有兴趣了解在测试和训练期间如何处理类和样本权重。查看keras github repo以及metric和loss的代码,似乎loss或metric都没有受到它们的影响。打印的值很难在训练代码中跟踪,例如model.fit()及其相应的tensorflow后端训练函数。所以我决定制作一个测试代码来测试可能的场景,请参阅下面的代码。结论是,两者class_weight and sample_weight只影响训练损失,对任何指标或验证损失没有影响。有点令人惊讶的是val_sample_weights(你可以指定)似乎什么也没做(??)。

  1. 此类问题始终取决于您的问题、日期的偏差程度以及您尝试优化模型的方式。您是否针对准确性进行了优化,那么只要训练数据与模型在生产中时的偏差相同,只需训练即可获得最佳结果,而无需任何过采样/欠采样和/或类别权重。 另一方面,如果您有一个类比另一个类更重要(或更昂贵)的东西,那么您应该对数据进行加权。例如,在预防欺诈方面,欺诈通常比非欺诈的收入要昂贵得多。我建议您尝试未加权的类别、加权的类别和一些欠采样/过采样,并检查哪一个可以提供最佳的验证结果。使用最能比较不同模型的验证函数(或编写自己的函数)(例如,根据成本对真阳性、假阳性、真阴性和假阴性进行不同的加权)。 一个相对较新的损失函数在偏斜数据的 Kaggle 竞赛中取得了很好的成绩:Focal-loss. Focal-loss减少过采样/欠采样的需要。很遗憾Focal-loss还不是 keras 中的内置函数,但可以手动编程。

  2. 是的,我认为你是对的。我通常使用sample_weight有两个原因。 1、训练数据具有某种测量不确定性,如果已知的话,可以用来对准确数据进行加权,而不是对不准确测量进行加权。或者2,我们可以对新数据赋予比旧数据更大的权重,迫使模型更快地适应新行为,而不忽略有价值的旧数据。

比较有和没有的代码class_weights and sample_weights,同时保持模型和其他一切静止。

import tensorflow as tf
import numpy as np

data_size = 100
input_size=3
classes=3

x_train = np.random.rand(data_size ,input_size)
y_train= np.random.randint(0,classes,data_size )
#sample_weight_train = np.random.rand(data_size)
x_val = np.random.rand(data_size ,input_size)
y_val= np.random.randint(0,classes,data_size )
#sample_weight_val = np.random.rand(data_size )

inputs = tf.keras.layers.Input(shape=(input_size))
pred=tf.keras.layers.Dense(classes, activation='softmax')(inputs)

model = tf.keras.models.Model(inputs=inputs, outputs=pred)

loss = tf.keras.losses.sparse_categorical_crossentropy
metrics = tf.keras.metrics.sparse_categorical_accuracy

model.compile(loss=loss , metrics=[metrics], optimizer='adam')

# Make model static, so we can compare it between different scenarios
for layer in model.layers:
    layer.trainable = False

# base model no weights (same result as without class_weights)
# model.fit(x=x_train,y=y_train, validation_data=(x_val,y_val))
class_weights={0:1.,1:1.,2:1.}
model.fit(x=x_train,y=y_train, class_weight=class_weights, validation_data=(x_val,y_val))
# which outputs:
> loss: 1.1882 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1965 - val_sparse_categorical_accuracy: 0.3100

#changing the class weights to zero, to check which loss and metric that is affected
class_weights={0:0,1:0,2:0}
model.fit(x=x_train,y=y_train, class_weight=class_weights, validation_data=(x_val,y_val))
# which outputs:
> loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1945 - val_sparse_categorical_accuracy: 0.3100

#changing the sample_weights to zero, to check which loss and metric that is affected
sample_weight_train = np.zeros(100)
sample_weight_val = np.zeros(100)
model.fit(x=x_train,y=y_train,sample_weight=sample_weight_train, validation_data=(x_val,y_val,sample_weight_val))
# which outputs:
> loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1931 - val_sparse_categorical_accuracy: 0.3100

使用权重和不使用权重之间存在一些小偏差(即使所有权重都是一),可能是由于对加权和未加权数据使用不同的后端函数进行拟合,或者是由于舍入误差?

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

什么时候在keras中使用sample_weights合适? 的相关文章

  • 将数据从 python pandas 数据框导出或写入 MS Access 表

    我正在尝试将数据从 python pandas 数据框导出到现有的 MS Access 表 我想用已更新的数据替换 MS Access 表 在 python 中 我尝试使用 pandas to sql 但收到错误消息 我觉得很奇怪 使用 p
  • OpenCV Python cv2.mixChannels()

    我试图将其从 C 转换为 Python 但它给出了不同的色调结果 In C Transform it to HSV cvtColor src hsv CV BGR2HSV Use only the Hue value hue create
  • 如何在flask中使用g.user全局

    据我了解 Flask 中的 g 变量 它应该为我提供一个全局位置来存储数据 例如登录后保存当前用户 它是否正确 我希望我的导航在登录后在整个网站上显示我的用户名 我的观点包含 from Flask import g among other
  • 如何替换 pandas 数据框列中的重音符号

    我有一个数据框dataSwiss其中包含瑞士城市的信息 我想用普通字母替换带有重音符号的字母 这就是我正在做的 dataSwiss Municipality dataSwiss Municipality str encode utf 8 d
  • python 相当于 R 中的 get() (= 使用字符串检索符号的值)

    在 R 中 get s 函数检索名称存储在字符变量 向量 中的符号的值s e g X lt 10 r lt XVI s lt substr r 1 1 X get s 10 取罗马数字的第一个符号r并将其转换为其等效整数 尽管花了一些时间翻
  • 如何从网页中嵌入的 Tableau 图表中抓取工具提示值

    我试图弄清楚是否有一种方法以及如何使用 python 从网页中的 Tableau 嵌入图形中抓取工具提示值 以下是当用户将鼠标悬停在条形上时带有工具提示的图表示例 我从要从中抓取的原始网页中获取了此网址 https covid19 colo
  • ubuntu 20.04 上无法获取卷积算法错误~tensorflow-gpu

    我有一个 NVIDIA 2070 RTX GPU 我的操作系统是 Ubuntu20 04 我已经使用 conda 安装了tensorflow gpu 包 我有not安装了 CUDA toolkit 我相信它还会安装 CUDA toolkit
  • SQLALchemy .query:类“Car”的未解析属性引用“query”

    我有一个这里已经提到的问题https youtrack jetbrains com issue PY 44557 https youtrack jetbrains com issue PY 44557 但我还没有找到解决方案 我使用 Pyt
  • OpenCV 无法从 MacBook Pro iSight 捕获

    几天后 我无法再从 opencv 应用程序内部打开我的 iSight 相机 cap cv2 VideoCapture 0 返回 并且cap isOpened 回报true 然而 cap grab 刚刚返回false 有任何想法吗 示例代码
  • 如何使用 OpencV 从 Firebase 读取图像?

    有没有使用 OpenCV 从 Firebase 读取图像的想法 或者我必须先下载图片 然后从本地文件夹执行 cv imread 功能 有什么办法我可以使用cv imread link of picture from firebase 您可以
  • 从 Flask 访问 Heroku 变量

    我已经使用以下命令在 Heroku 配置中设置了数据库变量 heroku config add server xxx xxx xxx xxx heroku config add user userName heroku config add
  • Python 的“zip”内置函数的 Ruby 等价物是什么?

    Ruby 是否有与 Python 内置函数等效的东西zip功能 如果不是 做同样事情的简洁方法是什么 一些背景信息 当我试图找到一种干净的方法来进行涉及两个数组的检查时 出现了这个问题 如果我有zip 我可以写这样的东西 zip a b a
  • Pygame:有没有简单的方法可以找到按下的任何字母数字的字母/数字?

    我目前正在开发的游戏需要让人们以自己的名义在高分板上计时 我对如何处理按键有点熟悉 但我只处理过寻找特定的按键 有没有一种简单的方法可以按下任意键的字母 而不必执行以下操作 for event in pygame event get if
  • python获取上传/下载速度

    我想在我的计算机上监控上传和下载速度 一个名为 conky 的程序已经在 conky conf 中执行了以下操作 Connection quality alignr wireless link qual perc wlan0 downspe
  • Jupyter Notebook 内核一直很忙

    我已经安装了 anaconda 并且 python 在 Spyder IPython 等中工作正常 但是我无法运行 python 笔记本 内核被创建 它也连接 但它始终显示黑圈忙碌符号 防火墙或防病毒软件没有问题 我尝试过禁用两者 我也无法
  • 使用 Python 绘制 2D 核密度估计

    I would like to plot a 2D kernel density estimation I find the seaborn package very useful here However after searching
  • Python:如何将列表列表的元素转换为无向图?

    我有一个程序 可以检索 PubMed 出版物列表 并希望构建一个共同作者图 这意味着对于每篇文章 我想将每个作者 如果尚未存在 添加为顶点 并添加无向边 或增加每个合著者之间的权重 我设法编写了第一个程序 该程序检索每个出版物的作者列表 并
  • 从列表指向字典变量

    假设你有一个清单 a 3 4 1 我想用这些信息来指向字典 b 3 4 1 现在 我需要的是一个常规 看到该值后 在 b 的位置内读写一个值 我不喜欢复制变量 我想直接改变变量b的内容 假设b是一个嵌套字典 你可以这样做 reduce di
  • 如何使用 Pycharm 安装 tkinter? [复制]

    这个问题在这里已经有答案了 I used sudo apt get install python3 6 tk而且效果很好 如果我在终端中打开 python Tkinter 就可以工作 但我无法将其安装在我的 Pycharm 项目上 pip
  • Statsmodels.formula.api OLS不显示截距的统计值

    我正在运行以下源代码 import statsmodels formula api as sm Add one column of ones for the intercept term X np append arr np ones 50

随机推荐

  • 如何加速“独特”数据框搜索

    我有一个数据框 其尺寸为 2377426 行 x 2 列 如下所示 Name Seq 428293 ENSE00001892940 ENSE00001929862 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
  • 如何每隔x秒重复执行一个函数?

    我想永远每 60 秒重复执行一次 Python 中的函数 就像NSTimer http web archive org web 20090823012700 http developer apple com 80 DOCUMENTATION
  • java持久化内存泄漏

    我的表中有 100 万行 我想获取所有行 但是当我尝试通过分页获取带有 jpa 的所有行时 我收到 java 堆错误 你认为我错过了什么吗 任何建议 int counter 0 while counter gt 0 javax persis
  • 在 Docker Alpine 上安装seaborn

    我正在尝试安装seaborn使用这个 Dockerfile FROM alpine latest RUN apk add update python py pip python dev RUN pip install seaborn CMD
  • 获取客户端隐藏字段的值

    单击服务器端的按钮 我将表中的列中的值分配给隐藏字段 Dim dsGetEnquiryDetails dbl usp GetEnquiryRegisterDetails Val lblEnquiryRegisterID Text AsQue
  • Docker:Opensearch 拒绝与 docker 中的 opensearch 文档中的示例连接

    我正在 docker 容器上运行 opensearch v 1 0 0 并在localhost 请考虑这个问题IS NOT和这篇文章一样 Opensearch Docker Image 无法建立新连接 Errno 111 连接被拒绝 htt
  • 对贝塞尔曲线的点进行动画处理[重复]

    这个问题在这里已经有答案了 是否可以对贝塞尔曲线的点进行动画处理 我正在尝试从直线到箭头的平滑过渡 这是该行在代码中的样子 Color Declarations UIColor white UIColor colorWithRed 1 gr
  • 计算沿轴的直方图

    有没有办法沿着 nD 数组的轴计算许多直方图 我目前使用的方法是for循环迭代所有其他轴并计算numpy histogram 对于每个生成的一维数组 import numpy import itertools data numpy rand
  • C++11/14 中的 Boost.Pointer 容器被 std::unique_ptr 废弃了?

    Does std unique ptr make Boost Pointer容器C 11 14 中的库已过时吗 在 C 98 03 中没有移动语义 并且有一个智能指针 例如shared ptr与引用计数相关overhead 对于参考计数块
  • 使 FAB 响应软键盘显示/隐藏更改

    我看过各种关于 FAB 响应屏幕底部 Snackbar 弹出窗口以及滚动敏感 FAB 的帖子 但是否有一些实施FloatingActionButton Behavior 或类似 将 FAB 移至键盘上方当它出现时 现在 当我单击某个按钮时
  • 将 IE 窗口置于屏幕前面

    我正在动态创建新的 IE 浏览器实例 并从那里打开一个 aspx 页面 一切正常 但浏览器没有在屏幕前面弹出 当我从那里单击它时 能够在任务栏中看到 Aspx 页面 它会出现在前面 如何在 IE 创建后立即将该页面显示在所有屏幕的前面 我已
  • 如何处理来自不同时区的日期时间

    我有一个 django 应用程序 它在数据库 postgres 中存储 UTC 的日期时间 它在世界各地都有用户 但在应用程序逻辑中 我根据本地时间范围进行了一些验证 即用户在瓜亚基尔并且整个周日都发生了一些事情 我在执行它时遇到问题并进行
  • 调用线程无法访问该对象,因为另一个线程拥有它

    我正在尝试从 PowerShell 检索打印队列列表 如下所示 但我越来越 The calling thread cannot access this object because a different thread owns it 发生
  • 如何在Python中进行二次排序?

    如果我有一个数字列表 4 2 5 1 3 我想先按某个功能对其进行排序f然后对于具有相同值的数字f我希望它按数字的大小排序 这段代码似乎不起作用 list5 sorted list5 list5 sorted list5 key lambd
  • webpack 在react.js 中无法正常工作

    我使用创建了一个 hello world 反应应用程序create react app命令 然后我尝试使用运行相同的文件webpack 但它不能正常工作 比如 ico css文件是not rendering到屏幕上 请帮我解决这个问题 we
  • 在 Observable Angular js 2 中迭代 json 字符串

    以下是我的html代码 tr td c name td td c skill td tr 在我的 json 中 name abc skill xyz 这是可行的 但我需要迭代这个 json 字符串 var obj a 1 b 2 for v
  • 如何在运行时重新转换类?

    我正在尝试修改一个已加载到 JVM 中的类 我找到的解决方案是这样的 将代理附加到 PID 指定的 JVM 例如8191 代码 AttachTest 从 JVM 中已加载的类中找到您要修改的类 例如 8191 使用仪器添加变压器 代码 Ag
  • C++ 进程因状态 3 混乱而终止

    我对编程非常陌生 但在过去一周左右的时间里一直在关注 C 教程并积累了许多 PDF 来帮助我 我在其中或网上找不到任何足够清楚地回答我的问题的内容 请原谅我的新手 相关代码 日志文件 hpp HEADER CLASS INTERFACE F
  • 检索 Linkedin 视频帖子 (ugcPost API) 的缩略图

    我尝试使用 ugcPost api 检索视频帖子的缩略图 但没有成功 我总是检索一个空的缩略图数组 关于文档检索 UGC 帖子 https learn microsoft com en us linkedin marketing integ
  • 什么时候在keras中使用sample_weights合适?

    根据这个question https stackoverflow com questions 43459317 keras class weight vs sample weights in the fit generator 我了解到cl