背景
我有一个有 5 个标签的多标签分类问题(例如[1 0 1 1 0]
)。因此,我希望我的模型能够改进固定召回率、精确召回率 AUC 或 ROC AUC 等指标。
使用损失函数没有意义(例如binary_crossentropy
)这与我想要优化的性能测量没有直接关系。因此,我想使用 TensorFlow 的global_objectives.recall_at_precision_loss()
或类似于损失函数。
- 相关GitHub:https://github.com/tensorflow/models/tree/master/research/global_objectives
- 相关论文(不可分解目标的可扩展学习):https://arxiv.org/abs/1608.04802
非公制
我不是在寻求实施tf.metrics
。我已经成功做到了以下几点:https://stackoverflow.com/a/50566908/3399066
Problem
我认为我的问题可以分为2个问题:
- 如何使用
global_objectives.recall_at_precision_loss()
或类似的?
- 如何在具有 TF 后端的 Keras 模型中使用它?
问题1
有一个文件叫loss_layers_example.py
on 全球目标 GitHub 页面(同上)。不过,由于我对TF没有太多经验,所以不太明白如何使用它。另外,谷歌搜索TensorFlow recall_at_precision_loss example
or TensorFlow Global objectives example
不会给我任何更清楚的例子。
我该如何使用global_objectives.recall_at_precision_loss()
在一个简单的 TF 示例中?
问题2
会像(在 Keras 中):model.compile(loss = ??.recall_at_precision_loss, ...)
足够?
我的感觉告诉我,由于使用了全局变量,事情比这更复杂loss_layers_example.py
.
如何使用类似于的损失函数global_objectives.recall_at_precision_loss()
在喀拉斯?
与马蒂诺的答案类似,但会从输入推断形状(将其设置为固定批量大小对我来说不起作用)。
外部函数并不是严格必要的,但在配置损失函数时传递参数感觉更自然,尤其是当您的包装器在外部模块中定义时。
import keras.backend as K
from global_objectives.loss_layers import precision_at_recall_loss
def get_precision_at_recall_loss(target_recall):
def precision_at_recall_loss_wrapper(y_true, y_pred):
y_true = K.reshape(y_true, (-1, 1))
y_pred = K.reshape(y_pred, (-1, 1))
return precision_at_recall_loss(y_true, y_pred, target_recall)[0]
return precision_at_recall_loss_wrapper
然后,在编译模型时:
TARGET_RECALL = 0.9
model.compile(optimizer='adam', loss=get_precision_at_recall_loss(TARGET_RECALL))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)