focal loss的几种实现版本(Keras/Tensorflow)

2023-11-02

起源于在工作中使用focal loss遇到的一个bug,我仔细的学习多个靠谱的focal loss讲解及实现版本

通过测试,我发现了这样一个奇怪的现象,几乎每个版本的focal loss实现对同样的输入计算出的loss都是不同的。

通过仔细的比对和思考,我总结了三种我认为正确的focal loss实现方法,并将代码分享出来。

完整的代码我整理到了我的github代码库AI-Toolbox中,代码戳这里

何为focal loss

focal loss 是随网络RetinaNet一起提出的一个令人惊艳的损失函数 paper 下载,主要针对的是解决正负样本比例严重偏斜所产生的模型难以训练的问题。

这里假设你对focal loss有所了解,简单回顾下公式 ,focal loss的定义如下:
focal loss
其中
pt
公式中 γ {\gamma} γ α {\alpha} α是两个可以调节的超参数。

γ {\gamma} γ的含义更好理解一些,其作用是削弱那些模型已经能够较好预测的样本产生损失的权重,使模型更专注于学习那些较难的hard case。

α t {\alpha}_t αt的定义,原文中的表述是:

For notational convenience, we define αt analogously to how we defined pt

也就是说, α t {\alpha}_t αt的定义可以同理于 p t p_t pt的定义。它的作用是平衡类别之间的权重。

这里补充一句,网上能够找到的各种不同版本的focal loss实现,分歧基本都出现在这里。由于focal loss最初是伴随着目标检测中判断某个区域是物体or背景(二分类问题)出现的,当我们使用focal loss来解决更一般化的问题时(比如多分类问题、多标签预测问题), α t {\alpha}_t αt 如何定义便会产生分歧,很难说哪种是绝对正统的,因为不同的定义赋予了损失函数不同的功能,可以针对不同的问题。

让我们来看看,我总结的三种实现版本。

focal loss for binary classification

针对二分类版本的 focal loss 实现

def binary_focal_loss(gamma=2, alpha=0.25):
    """
    Binary form of focal loss.
    适用于二分类问题的focal loss
    
    focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
        where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
    References:
        https://arxiv.org/pdf/1708.02002.pdf
    Usage:
     model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    alpha = tf.constant(alpha, dtype=tf.float32)
    gamma = tf.constant(gamma, dtype=tf.float32)

    def binary_focal_loss_fixed(y_true, y_pred):
        """
        y_true shape need be (None,1)
        y_pred need be compute after sigmoid
        """
        y_true = tf.cast(y_true, tf.float32)
        alpha_t = y_true*alpha + (K.ones_like(y_true)-y_true)*(1-alpha)
    
        p_t = y_true*y_pred + (K.ones_like(y_true)-y_true)*(K.ones_like(y_true)-y_pred) + K.epsilon()
        focal_loss = - alpha_t * K.pow((K.ones_like(y_true)-p_t),gamma) * K.log(p_t)
        return K.mean(focal_loss)
    return binary_focal_loss_fixed

在使用本损失函数前,假设你已经将每个样本使用sigmoid映射成了一个0-1之间的数,代表二分类的概率。

在keras中使用此函数作为损失函数,只需在编译模型时指定损失函数为focal loss:

model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=optimizer)

focal loss for multi category 版本1

针对多分类问题或多标签问题的 focal loss 实现1.

前面已经提到网上不同的实现版本中 α t {\alpha}_t αt的定义存在一定的分歧

当我们使用 α t {\alpha}_t αt来控制不同类别 / 标签 的权重时,实现代码如下:

def multi_category_focal_loss1(alpha, gamma=2.0):
    """
    focal loss for multi category of multi label problem
    适用于多分类或多标签问题的focal loss
    alpha用于指定不同类别/标签的权重,数组大小需要与类别个数一致
    当你的数据集不同类别/标签之间存在偏斜,可以尝试适用本函数作为loss
    Usage:
     model.compile(loss=[multi_category_focal_loss1(alpha=[1,2,3,2], gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    epsilon = 1.e-7
    alpha = tf.constant(alpha, dtype=tf.float32)
    #alpha = tf.constant([[1],[1],[1],[1],[1]], dtype=tf.float32)
    #alpha = tf.constant_initializer(alpha)
    gamma = float(gamma)
    def multi_category_focal_loss1_fixed(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
        y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)
        ce = -tf.log(y_t)
        weight = tf.pow(tf.subtract(1., y_t), gamma)
        fl = tf.matmul(tf.multiply(weight, ce), alpha)
        loss = tf.reduce_mean(fl)
        return loss
    return multi_category_focal_loss1_fixed

注意,你需要将 α {\alpha} α指定为一个数组,数组大小需要与类别个数一致,代表着每一个类别对应的权重。

当你的数据集不同类别/标签之间存在偏斜,可以尝试适用本函数作为loss。

我们将核心函数copy出来做一个简单的测试,来验证 α {\alpha} α平衡类别间权重的有效性。

import os
from keras import backend as K
import tensorflow as tf
import numpy as np

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

def multi_category_focal_loss1(y_true, y_pred):
    epsilon = 1.e-7
    gamma = 2.0
    #alpha = tf.constant([[2],[1],[1],[1],[1]], dtype=tf.float32)
    alpha = tf.constant([[1],[1],[1],[1],[1]], dtype=tf.float32)

    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
    y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)
    ce = -tf.log(y_t)
    weight = tf.pow(tf.subtract(1., y_t), gamma)
    fl = tf.matmul(tf.multiply(weight, ce), alpha)
    loss = tf.reduce_mean(fl)
    return loss
Y_true = np.array([[1, 1, 1, 1, 1], [0, 0, 0, 0, 0]])
Y_pred = np.array([[0.3, 0.99, 0.8, 0.97, 0.85], [0.9, 0.05, 0.1, 0.09, 0]], dtype=np.float32)
print(K.eval(multi_category_focal_loss1(Y_true, Y_pred)))

假设我们正在处理一个5个输出的多label预测问题,按照上面的示例,假设我们的模型对于第一个label相比于其它标签的预测很糟糕(这可能是由于第一个label出现的概率很小,在算损失时没有话语权导致的)。

上面代码的运算结果是1.2347984

我们使用 α {\alpha} α来调节第一个label的权重,尝试将 α {\alpha} α修改为:

alpha = tf.constant([[2],[1],[1],[1],[1]], dtype=tf.float32)

重新运行,损失增大为2.4623184,说明损失函数成功的放大了第一个类别的权重,会使模型更重视第一个label的正确预测。

focal loss for multi category 版本2

针对多分类问题或多标签问题的 focal loss 实现2.

当我们使用 α t {\alpha}_t αt 来控制真值y_true为 1 or 0 时的权重时

即 y = 1 时的权重为 α {\alpha} α, y = 0时的权重为 1 − α 1-{\alpha} 1α

实现代码如下:

def multi_category_focal_loss2(gamma=2., alpha=.25):
    """
    focal loss for multi category of multi label problem
    适用于多分类或多标签问题的focal loss
    alpha控制真值y_true为1/0时的权重
        1的权重为alpha, 0的权重为1-alpha
    当你的模型欠拟合,学习存在困难时,可以尝试适用本函数作为loss
    当模型过于激进(无论何时总是倾向于预测出1),尝试将alpha调小
    当模型过于惰性(无论何时总是倾向于预测出0,或是某一个固定的常数,说明没有学到有效特征)
        尝试将alpha调大,鼓励模型进行预测出1。
    Usage:
     model.compile(loss=[multi_category_focal_loss2(alpha=0.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    epsilon = 1.e-7
    gamma = float(gamma)
    alpha = tf.constant(alpha, dtype=tf.float32)

    def multi_category_focal_loss2_fixed(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
    
        alpha_t = y_true*alpha + (tf.ones_like(y_true)-y_true)*(1-alpha)
        y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)
        ce = -tf.log(y_t)
        weight = tf.pow(tf.subtract(1., y_t), gamma)
        fl = tf.multiply(tf.multiply(weight, ce), alpha_t)
        loss = tf.reduce_mean(fl)
        return loss
    return multi_category_focal_loss2_fixed

注意,你需要将 α {\alpha} α指定为一个数组,数组大小需要与类别个数一致,代表着每一个类别对应的权重。

当你的模型欠拟合,学习存在困难时,可以尝试适用本函数作为loss

当模型过于激进(无论何时总是倾向于预测出1),尝试将alpha调小

当模型过于“懒惰”时(无论何时总是倾向于预测出0,或是某一个固定的常数,说明没有学到有效特征),尝试将alpha调大,鼓励模型预测出1。

同样地,我们将核心函数copy出来做一个简单的测试,来验证 α {\alpha} α平衡0-1权重的有效性。

import os
from keras import backend as K
import tensorflow as tf
import numpy as np

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

def multi_category_focal_loss2_fixed(y_true, y_pred):
    epsilon = 1.e-7
    gamma=2.
    alpha = tf.constant(0.5, dtype=tf.float32)

    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)

    alpha_t = y_true*alpha + (tf.ones_like(y_true)-y_true)*(1-alpha)
    y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)
    ce = -tf.log(y_t)
    weight = tf.pow(tf.subtract(1., y_t), gamma)
    fl = tf.multiply(tf.multiply(weight, ce), alpha_t)
    loss = tf.reduce_mean(fl)
    return loss
Y_true = np.array([[1, 1, 1, 1, 1], [0, 1, 1, 1, 1]])
Y_pred = np.array([[0.9, 0.99, 0.8, 0.97, 0.85], [0.9, 0.95, 0.91, 0.99, 1]], dtype=np.float32)
print(K.eval(multi_category_focal_loss2_fixed(Y_true, Y_pred)))

仍然假设我们正在处理一个5个输出的多label预测问题

按照上面的示例,假设这次我们遇到的问题是,所有的标签都会有很高的概率出现1,这时我们的模型发现了一个投机取巧的办法,将每个结果都预测为1,即可得到很小的loss,于是模型严重的欠拟合。

上面代码的运算结果是0.093982555,如我们所料,损失并不大,这显然会影响模型成功收敛。

我们使用 α {\alpha} α来抑制模型输出1的权重,尝试将 α {\alpha} α修改为:

alpha = tf.constant(0.25, dtype=tf.float32)

重新运行,损失增大为0.14024596,说明损失函数成功的放大了这种投机行为的损失。

参考文献

focal loss paper
Keras自定义Loss函数
Keras中自定义复杂的loss函数
github: focal-loss-keras 实现1
github: focal-loss-keras 实现2
kaggle kernel: FocalLoss for Keras
Focal Loss理解
应用:Multi-class classification with focal loss for imbalanced datasets

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

focal loss的几种实现版本(Keras/Tensorflow) 的相关文章

  • 使用 psycopg2 在 python 中执行查询时出现“编程错误:语法错误位于或附近”

    我正在运行 Python v 2 7 和 psycopg2 v 2 5 我有一个 postgresql 数据库函数 它将 SQL 查询作为文本字段返回 我使用以下代码来调用该函数并从文本字段中提取查询 cur2 execute SELECT
  • 使 django 服务器可以在 LAN 中访问

    我已经安装了Django服务器 可以如下访问 http localhost 8000 get sms http 127 0 0 1 8000 get sms 假设我的IP是x x x x 当我这样做时 从同一网络下的另一台电脑 my ip
  • Python(Selenium):如何通过登录重定向/组织登录登录网站

    我不是专业程序员 所以请原谅任何愚蠢的错误 我正在做一些研究 我正在尝试使用 Selenium 登录数据库来搜索大约 1000 个术语 我有两个问题 1 重定向到组织登录页面后如何使用 Selenium 登录 2 如何检索数据库 在我解决
  • 使用带有关键字参数的 map() 函数

    这是我尝试使用的循环map功能于 volume ids 1 2 3 4 5 ip 172 12 13 122 for volume id in volume ids my function volume id ip ip 我有办法做到这一点
  • Python - StatsModels、OLS 置信区间

    在 Statsmodels 中 我可以使用以下方法拟合我的模型 import statsmodels api as sm X np array 22000 13400 47600 7400 12000 32000 28000 31000 6
  • 如何从网页中嵌入的 Tableau 图表中抓取工具提示值

    我试图弄清楚是否有一种方法以及如何使用 python 从网页中的 Tableau 嵌入图形中抓取工具提示值 以下是当用户将鼠标悬停在条形上时带有工具提示的图表示例 我从要从中抓取的原始网页中获取了此网址 https covid19 colo
  • 是否可以忽略一行的pyright检查?

    我需要忽略一行的pyright 检查 有什么特别的评论吗 def create slog group SLogGroup data Optional dict None SLog insert one SLog group group da
  • 使用 Tkinter 显示 numpy 数组中的图像

    我对 Python 缺乏经验 第一次使用 Tkinter 制作一个 UI 显示我的数字分类程序与 mnist 数据集的结果 当图像来自 numpy 数组而不是我的 PC 上的文件路径时 我有一个关于在 Tkinter 中显示图像的问题 我为
  • Python pickle:腌制对象不等于源对象

    我认为这是预期的行为 但想检查一下 也许找出原因 因为我所做的研究结果是空白 我有一个函数可以提取数据 创建自定义类的新实例 然后将其附加到列表中 该类仅包含变量 然后 我使用协议 2 作为二进制文件将该列表腌制到文件中 稍后我重新运行脚本
  • OpenCV 无法从 MacBook Pro iSight 捕获

    几天后 我无法再从 opencv 应用程序内部打开我的 iSight 相机 cap cv2 VideoCapture 0 返回 并且cap isOpened 回报true 然而 cap grab 刚刚返回false 有任何想法吗 示例代码
  • 如何加速Python中的N维区间树?

    考虑以下问题 给定一组n间隔和一组m浮点数 对于每个浮点数 确定包含该浮点数的区间子集 这个问题已经通过构建一个解决区间树 https en wikipedia org wiki Interval tree 或称为范围树或线段树 已经针对一
  • BeautifulSoup 中的嵌套标签 - Python

    我在网站和 stackoverflow 上查看了许多示例 但找不到解决我的问题的通用解决方案 我正在处理一个非常混乱的网站 我想抓取一些数据 标记看起来像这样 table tbody tr tr tr td td td table tr t
  • Python 的“zip”内置函数的 Ruby 等价物是什么?

    Ruby 是否有与 Python 内置函数等效的东西zip功能 如果不是 做同样事情的简洁方法是什么 一些背景信息 当我试图找到一种干净的方法来进行涉及两个数组的检查时 出现了这个问题 如果我有zip 我可以写这样的东西 zip a b a
  • 如何使用Python创建历史时间线

    So I ve seen a few answers on here that helped a bit but my dataset is larger than the ones that have been answered prev
  • python获取上传/下载速度

    我想在我的计算机上监控上传和下载速度 一个名为 conky 的程序已经在 conky conf 中执行了以下操作 Connection quality alignr wireless link qual perc wlan0 downspe
  • Fabric env.roledefs 未按预期运行

    On the 面料网站 http docs fabfile org en 1 10 usage execution html 给出这个例子 from fabric api import env env roledefs web hosts
  • 对年龄列进行分组/分类

    我有一个数据框说df有一个柱子 Ages gt gt gt df Age 0 22 1 38 2 26 3 35 4 35 5 1 6 54 我想对这个年龄段进行分组并创建一个像这样的新专栏 If age gt 0 age lt 2 the
  • Python:如何将列表列表的元素转换为无向图?

    我有一个程序 可以检索 PubMed 出版物列表 并希望构建一个共同作者图 这意味着对于每篇文章 我想将每个作者 如果尚未存在 添加为顶点 并添加无向边 或增加每个合著者之间的权重 我设法编写了第一个程序 该程序检索每个出版物的作者列表 并
  • 导入错误:没有名为 site 的模块 - mac

    我已经有这个问题几个月了 每次我想获取一个新的 python 包并使用它时 我都会在终端中收到此错误 ImportError No module named site 我不知道为什么会出现这个错误 实际上 我无法使用任何新软件包 因为每次我
  • 如何使用 Pycharm 安装 tkinter? [复制]

    这个问题在这里已经有答案了 I used sudo apt get install python3 6 tk而且效果很好 如果我在终端中打开 python Tkinter 就可以工作 但我无法将其安装在我的 Pycharm 项目上 pip

随机推荐

  • mac os 10.12安全性与隐私没有任何来源的解决办法

    到mac os10 12后 有很多签名不对的软件安装会装不了 比如 解决办法如下 在命令行中输入 sudo spctl master disable 这样就可以在系统偏好设置 安全性与隐私中看到任何来源了 这样像很多本来安装不了的软件又可以
  • ValidPalindrome(回文字符串的判断)

    author LemonLin Description StringValidPalindrome date 2019 5 9 16 40 Given a string determine if it is a palindrome con
  • TensorFlow:数据集加载

    TensorFlow 数据集加载 数据集加载 数据集加载 1 keas datasets tensoflow keras提供了keras datasets的接口 常见的数据集 Boston housing price regerssion
  • 正则的校验

    验证版本号的正则表达式 要求 必须是三位 x x x的形式 每位x的范围分别为1 99 0 99 0 99 不允许的情况0 x x 01 x x x 0x x x 00 x x x 00 x x 0x 满足这些条件的正则为 1 9 d 1
  • shell 批量创建多个用户(导入用户表)

    批量创建用户和密码 查看 创建成功 复制代码 bin bash ULIST cat root user txt for UNAME in ULIST do useradd UNAME echo 123456 passwd stdin UNA
  • 记录素材帖,日常吐槽帖

    本人考研繁忙 无暇写博客 明年暑假腹泻式更新 等我上岸 冲冲冲 C语言二维指针的表示 int ra 2 matrix js的参数代理 考研数据结构的所有重要算法 react让组件强制更新的一个损招 Taro全面更新版本的指令 c 数组传参后
  • Selenium防踩坑 - StaleElementReferenceException 解决方案

    主要内容 1 异常原因 2 解决方案 1 异常原因 在执行脚本时 有时候引用元素对象会抛出如下异常 selenium common exceptions StaleElementReferenceException Message stal
  • 企业建设数字化工厂之前需要准备哪些硬件设施

    随着数字化技术的快速发展 数字化工厂已经成为了企业建设的重要方向 数字化工厂管理系统能够提高生产效率 降低成本 保证产品质量 为企业可持续发展提供有力支持 然而 建设数字化工厂需要准备一系列的硬件设施 以确保数字化工厂的正常运行 那么企业建
  • 关于文件上传漏洞的观点(upload-labs第九关)

    关于文件上传漏洞的观点 upload labs第九关 is upload false msg null if isset POST submit if file exists UPLOAD PATH deny ext array php p
  • java-web 过滤器 & 监听器 & 拦截器

    Tomcat 的容器分为四个等级 真正管理 Servlet 的容器是 Context 容器 一个 Context 对应一个 Web 工程 在 Tomcat 的配置文件中可以很容易发现这一点 如下 Context 配置参数
  • 有关校园网无法开启wifi的简单解决方法

    作为一个新时代的大学生 没有wifi的世界就是个噩梦 以前用的猎豹wifi 但发现卸载猎豹wifi后无法登陆校园网后 果断抛弃了这个家伙 现在使用的是一个叫360免费wifi的东西 现在开着校园网客户端的情况下打开360wifi 但是问题来
  • 如何用python远程探查每天的网页访问记录

    前言 利用Python制作远程查看别人电脑的操作记录 与其它教程类似 都是通过邮件返回 利用程序得到目标电脑浏览器当中的访问记录 生产一个文本并发送到你自己的邮箱 当然这个整个过程除了你把python程序植 入目标电脑外 其它的操作都是自动
  • nginx 报错[emerg]: unknown directive “锘? in E:\nginx-1.18.0/conf/nginx.conf:3

    报错 nginx 报错 emerg 32408 14080 unknown directive 锘 in E nginx 1 18 0 conf nginx conf 3 原因 使用nginx服务时 用txt记事本打开编辑了nginx co
  • 清除浮动的五种方法以及优缺点

    方法一 额外标签法 给谁清除浮动 就在其后额外添加一个空白标签 给其设置clear both 优点 通俗易懂 书写方便 缺点 添加许多无意义的标签 结构化比较差 clear both 本质就是闭合浮动 就是让父盒子闭合出口和入口 不让子盒子
  • Python实例:用Pandas处理表格(简单的增删改查)

    目录 任务描述 实现过程 任务描述 描述 现有一个excel表格 补充SCI模板 其中包括6个子表 中科院1区 表1 JCR Q1 表2 教研室补充 表 CCF A 表 CCF B 表 CCF C 表 每个表格第一列为期刊名称 需要为这些期
  • 基于springboot+vue的网上商城管理系统,附源码+数据库+lw文档+PPT,适合课程设计、毕业设计

    1 项目介绍 在Internet高速发展的今天 我们生活的各个领域都涉及到计算机的应用 其中包括网上图书商城的网络应用 在外国网上图书商城已经是很普遍的方式 不过国内的管理网站可能还处于起步阶段 网上图书商城具有网上图书信息管理功能的选择
  • Visual Studio在Release模式下开启debug调试,编译器提示变量已被优化掉,因而不可用

    系列文章目录 文章目录 系列文章目录 前言 一 解决办法 1 修改工程属性 参考 前言 我们在编写代码的时候 如果用到别人的库 而别人只提供了release版本 所有我们也只能生成release版本的工程 但是 我们又想调试代码 如果我们直
  • vue3 naiveui 自定义v-loading指令

    1 在sr目录下创建loading文件夹 包含index ts和index vue 2 index ts import render VNode createVNode from vue import Loading from index
  • 【Java基础知识 12】java枚举详解

    Java学习路线 搬砖工逆袭Java架构师 简介 Java领域优质创作者 CSDN哪吒公众号作者 Java架构师奋斗者 扫描主页左侧二维码 加入群聊 一起学习 一起进步 欢迎点赞 收藏 留言 目录 一 基本概念 二 枚举的优缺点 1 优点
  • focal loss的几种实现版本(Keras/Tensorflow)

    起源于在工作中使用focal loss遇到的一个bug 我仔细的学习多个靠谱的focal loss讲解及实现版本 通过测试 我发现了这样一个奇怪的现象 几乎每个版本的focal loss实现对同样的输入计算出的loss都是不同的 通过仔细的