class balanced loss pytorch 实现

2023-05-16

cb loss pytorch 实现,可直接调用
参考:https://github.com/vandit15/Class-balanced-loss-pytorch/blob/master/class_balanced_loss.py

import numpy as np
import torch
import torch.nn.functional as F



def focal_loss(logits, labels, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.

    Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.
    pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).

    Args:
      logits: A float tensor of size [batch, num_classes].
      labels: A float tensor of size [batch, num_classes].
      alpha: A float tensor of size [batch_size]
        specifying per-example weight for balanced cross entropy.
      gamma: A float scalar modulating loss from hard and easy examples.

    Returns:
      focal_loss: A float32 scalar representing normalized total loss.
    """
    bce_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        modulator = torch.exp(-gamma * labels * logits - gamma * torch.log(1 + torch.exp(-1.0 * logits)))

    loss = modulator * bce_loss

    weighted_loss = alpha * loss
    loss = torch.sum(weighted_loss)
    loss /= torch.sum(labels)
    return loss


class ClassBalancedLoss(torch.nn.Module):
    def __init__(self, samples_per_class=None, beta=0.9999, gamma=0.5, loss_type="focal"):
        super(ClassBalancedLoss, self).__init__()
        if loss_type not in ["focal", "sigmoid", "softmax"]:
            loss_type = "focal"
        if samples_per_class is None:
            num_classes = 5000
            samples_per_class = [1] * num_classes
        effective_num = 1.0 - np.power(beta, samples_per_class)
        weights = (1.0 - beta) / np.array(effective_num)
        self.constant_sum = len(samples_per_class)
        weights = (weights / np.sum(weights) * self.constant_sum).astype(np.float32)
        self.class_weights = weights
        self.beta = beta
        self.gamma = gamma
        self.loss_type = loss_type


    def update(self, samples_per_class):
        if samples_per_class is None:
            return
        effective_num = 1.0 - np.power(self.beta, samples_per_class)
        weights = (1.0 - self.beta) / np.array(effective_num)
        self.constant_sum = len(samples_per_class)
        weights = (weights / np.sum(weights) * self.constant_sum).astype(np.float32)
        self.class_weights = weights



    def forward(self, x, y):
        _, num_classes = x.shape
        labels_one_hot = F.one_hot(y, num_classes).float()
        weights = torch.tensor(self.class_weights, device=x.device).index_select(0, y)
        weights = weights.unsqueeze(1)
        if self.loss_type == "focal":
            cb_loss = focal_loss(x, labels_one_hot, weights, self.gamma)
        elif self.loss_type == "sigmoid":
            cb_loss = F.binary_cross_entropy_with_logits(x, labels_one_hot, weights)
        else:  # softmax
            pred = x.softmax(dim=1)
            cb_loss = F.binary_cross_entropy(pred, labels_one_hot, weights)
        return cb_loss


def test():
    torch.manual_seed(123)
    batch_size = 10
    num_classes = 5
    x = torch.rand(batch_size, num_classes)
    y = torch.randint(0, 5, size=(batch_size,))
    samples_per_class = [1, 2, 3, 4, 5]
    loss_type = "focal"
    loss_fn = ClassBalancedLoss(samples_per_class, loss_type=loss_type)
    loss = loss_fn(x, y)
    print(loss)


if __name__ == '__main__':
    test()

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

class balanced loss pytorch 实现 的相关文章

随机推荐

  • 阿克曼前轮转向车gazebo模型

    想要一个阿克曼转向结构车的gazebo模型 xff0c 要求能够用ros话题控制前进速度和前轮转角 令人惊讶的是 xff0c 网上基本没有这种模型 racecar模型 首先古月居提供了一个racecar的模型 xff0c 可以控制速度和前轮
  • Jetson TX2 在docker容器中import torch 报错的处理方式

    1 Jetson TX2 信息 xff1a 驱动版本 xff1a JetPack 4 6 1 2 docker信息 xff1a docker 镜像 xff1a pull 了 nvcr io nvidia l4t ml r32 7 1 py3
  • 史上最详细的PID教程——理解PID原理及优化算法

    Matlab动态PID仿真及PID知识梳理 云社区 华为云 huaweicloud com 位置式PID与增量式PID区别浅析 Z小旋 CSDN博客 增量式pid https zhuanlan zhihu com p 38337248 期望
  • JAVA经典试卷(理工)

    一 判断题 xff08 本大题共20小题 xff0c 每小题1分 xff0c 总计20分 xff09 1 xff0e final类能派生子类 2 xff0e 子类要调用父类的方法 xff0c 必须使用super关键字 3 xff0e Jav
  • git与gitee学习笔记

    随着时间推移 xff0c 除去常量 xff0c 任何事物都是在变化的 xff0c 如果用一根曲线表示 xff0c 横轴代表时间 xff0c 纵轴代表事物量 xff0c 那么所绘制的曲线 xff0c 在时间足够长的情况下 xff0c 必然是高
  • docker基础命令操作---镜像操作

    1 搜索官方仓库镜像 xff1a docker search image name 镜像名 例如 xff1a docker search nginx 命令执行结果参数说明 xff1a 参数 说明 NAME 镜像名称 DESCRIPTION
  • ESP8266连接天猫精灵(一)

    背景 接触天猫精灵后 xff0c 就想作一些小东西能接入天猫精灵 查看官网的文档后 xff0c 选择了ESP系列 xff0c 官方在文档中也比较推荐 读技术文档是个很难受的事情 xff0c 容易犯困 xff0c 最好有可以操作的设备 准备如
  • Windows下Boost库的安装与使用

    目录 1 基本介绍 2 下载安装 3 配置boost环境 xff08 VS2010 xff09 4 测试 1 基本介绍 Boost库是为C 43 43 语言标准库提供扩展的一些C 43 43 程序库的总称 xff0c 由Boost社区组织开
  • 嵌入式JetSon TX2上使用RealSense D435 (外加IMU芯片) 运行RTAB-Map与VINS-MONO的全流程记录

    本周成功的在JetSon TX2上移植了Vins Mono与RTAB Map xff0c 并使用摄像头RealSense D435顺利跑通了这两个框架 中间遇到了各种各样神奇的问题 xff0c 踩坑无数 xff0c 现整理记录一下整体流程
  • 微信公众号本地开发调试 - 无公网IP,内网穿透

    文章目录 前言1 配置本地服务器2 内网穿透2 1 下载安装cpolar内网穿透2 2 创建隧道 3 测试公网访问4 固定域名4 1 保留一个二级子域名4 2 配置二级子域名 5 使用固定二级子域名进行微信开发 前言 在微信公众号开发中 x
  • opencv图像通道 8UC1?

    转载自博主 64 马卫飞 https blog csdn net maweifei article details 51221259 CV lt bit depth gt S U F C lt number of channels gt b
  • gazebo中urdf、xacro、sdf模型文件关系

    gazebo的模型是用xml格式的文本文件来描述的 具体有三种形式 xff1a urdf xacro sdf urdf urdf是老的gazebo模型格式 xff0c 本身有一些缺陷 xff0c 也缺一些功能 但是网上很多gazebo模型都
  • 1_树莓派开启ssh服务

    树莓派3 开启 SSH 服务 原文链接 xff1a https blog csdn net qq 16775293 article details 88385393 文章目录 1 使用管理工具2 启动服务3 自动启动服务 3 1 Windo
  • 树莓派4b串口通信配置

    树莓派4b本身是两个串口 xff0c 运行ls dev al如下 xff1a 请注意 xff1a 在默认状态下 xff0c serial0 就是GPIO14 15 是映射到ttyS0的 xff08 就是MINI串口 xff1a dev tt
  • Pandas第三次作业20200907

    练习1 读取北向 csv 指定trade date为行索引 查看数据的基本信息 有无缺失值 对其缺失值进行处理 删除缺失值所在行 查看数据的基本信息 查看数据是否清洗完毕 index列没啥用 将index列删除 观察数据是否有重复行 将重复
  • 新手入门板卡硬件调试

    硬件电路调试步骤 新手入门板卡硬件调试一看 观察焊接情况二测 测量阻抗三接触式上电调试遇到的问题一般解决思路电源供电运放出现震荡测量时GND的选取振铃现象 新手入门板卡硬件调试 一看 观察焊接情况 1 拿到板卡后 xff0c 首先观察下焊接
  • 用shell 命令获取占用cpu 最多的前五位

    通常情况下使用ps axu 来获得系统中所有进程占用资源情况 xff0c 通常也可以使用top 命令来动态的获得系统中资源占用最多的进程 假设我们使用ps aux gt file tmp来获取linux系统中的进程占用资源情况 xff0c
  • 关于准确率accuracy和召回率recall的理解

    假设有100个样本 xff0c 其中正样本70 xff0c 负样本30 xff0c 这个是由数据集本身决定的 xff0c 机器要做的就是判别这100个样本中哪几个样本是正样本 xff0c 哪几个样本是负样本 现在机器做出了预测 xff1a
  • pytorch BERT文本分类保姆级教学

    pytorch BERT文本分类保姆级教学 本文主要依赖的工具为huggingface的transformers xff0c 更详细的解释可以查阅文档 定义模型 模型定义主要是tokenizer config和model的定义 xff0c
  • class balanced loss pytorch 实现

    cb loss pytorch 实现 xff0c 可直接调用 参考 xff1a https github com vandit15 Class balanced loss pytorch blob master class balanced