将 TensorFlow 损失全局目标 (recall_at_ precision_loss) 与 Keras(而非指标)结合使用

2023-11-22

背景

我有一个有 5 个标签的多标签分类问题(例如[1 0 1 1 0])。因此,我希望我的模型能够改进固定召回率、精确召回率 AUC 或 ROC AUC 等指标。

使用损失函数没有意义(例如binary_crossentropy)这与我想要优化的性能测量没有直接关系。因此,我想使用 TensorFlow 的global_objectives.recall_at_precision_loss()或类似于损失函数。

  • 相关GitHub:https://github.com/tensorflow/models/tree/master/research/global_objectives
  • 相关论文(不可分解目标的可扩展学习):https://arxiv.org/abs/1608.04802

非公制

我不是在寻求实施tf.metrics。我已经成功做到了以下几点:https://stackoverflow.com/a/50566908/3399066

Problem

我认为我的问题可以分为2个问题:

  1. 如何使用global_objectives.recall_at_precision_loss()或类似的?
  2. 如何在具有 TF 后端的 Keras 模型中使用它?

问题1

有一个文件叫loss_layers_example.py on 全球目标 GitHub 页面(同上)。不过,由于我对TF没有太多经验,所以不太明白如何使用它。另外,谷歌搜索TensorFlow recall_at_precision_loss example or TensorFlow Global objectives example不会给我任何更清楚的例子。

我该如何使用global_objectives.recall_at_precision_loss()在一个简单的 TF 示例中?

问题2

会像(在 Keras 中):model.compile(loss = ??.recall_at_precision_loss, ...)足够? 我的感觉告诉我,由于使用了全局变量,事情比这更复杂loss_layers_example.py.

如何使用类似于的损失函数global_objectives.recall_at_precision_loss()在喀拉斯?


与马蒂诺的答案类似,但会从输入推断形状(将其设置为固定批量大小对我来说不起作用)。

外部函数并不是严格必要的,但在配置损失函数时传递参数感觉更自然,尤其是当您的包装器在外部模块中定义时。

import keras.backend as K
from global_objectives.loss_layers import precision_at_recall_loss

def get_precision_at_recall_loss(target_recall): 
    def precision_at_recall_loss_wrapper(y_true, y_pred):
        y_true = K.reshape(y_true, (-1, 1)) 
        y_pred = K.reshape(y_pred, (-1, 1))   
        return precision_at_recall_loss(y_true, y_pred, target_recall)[0]
    return precision_at_recall_loss_wrapper

然后,在编译模型时:

TARGET_RECALL = 0.9
model.compile(optimizer='adam', loss=get_precision_at_recall_loss(TARGET_RECALL))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

将 TensorFlow 损失全局目标 (recall_at_ precision_loss) 与 Keras(而非指标)结合使用 的相关文章

随机推荐