精度类型

2024-05-22

使用 keras 库获得的精度如下:

model.compile(optimizer='sgd',
          loss='mse',
          metrics=[tf.keras.metrics.Precision()])

sklearn 计算出的哪种精度与 keras 计算出的精度相同?

precision_score(y_true, y_pred, average=???)
  1. macro
  2. micro
  3. weighted
  4. none

当您将 Zero_division 设置为 1 时会发生什么,如下所示?:

precision_score(y_true, y_pred, average=None, zero_division=1)

TLDR;默认为binary用于二元分类和micro用于多类分类。其他平均类型,例如None and macro也可以通过较小的修改来实现,如下所述。


这应该可以让您清楚地了解之间的差异tf.keras.Precision() and sklearn.metrics.precision_score()。让我们比较一下不同的场景。

场景一:二元分类

对于二元分类,y_true 和 y_pred 分别为 0,1 和 0-1。两者的实现都非常简单。

Sklearn 文档:仅报告 pos_label 指定的类的结果。仅当目标 (y_{true,pred}) 是二进制时才适用。

#Binary classification

from sklearn.metrics import precision_score
import tensorflow as tf

y_true = [0,1,1,1]
y_pred = [1,0,1,1]

print('sklearn precision: ',precision_score(y_true, y_pred, average='binary'))
#Only report results for the class specified by pos_label. 
#This is applicable only if targets (y_{true,pred}) are binary.

m = tf.keras.metrics.Precision()
m.update_state(y_true, y_pred)
print('tf.keras precision:',m.result().numpy())
sklearn precision:  0.6666666666666666
tf.keras precision: 0.6666667

场景2:多类分类(全局精度)

在这里,您正在使用多类标签,但您不必担心每个单独类的精度如何。您只需要一组全局 TP 和 FP 来计算总精度分数。在sklearn这是由参数设置的micro, 而在tf.keras这是默认设置Precision()

Sklearn 文档:通过计算总的真阳性、假阴性和假阳性来计算全局指标。

#Multi-class classification (global precision)

#3 classes, 6 samples
y_true = [[1,0,0],[0,1,0],[0,0,1],[1,0,0],[0,1,0],[0,0,1]]
y_pred = [[1,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,0],[0,1,0]]

print('sklearn precision: ',precision_score(y_true, y_pred, average='micro'))
#Calculate metrics globally by counting the total true positives, false negatives and false positives.

m.reset_states()
m = tf.keras.metrics.Precision()
m.update_state(y_true, y_pred)
print('tf.keras precision:',m.result().numpy())
sklearn precision:  0.3333333333333333
tf.keras precision: 0.33333334

场景3:多类分类(每个标签的二进制精度)

如果您想了解每个单独类别的精度,您会对这种情况感兴趣。在sklearn这是通过设置来完成的average参数为None, 而在tf.keras您必须使用以下方法分别实例化每个单独类的对象class_id.

Sklearn 文档:如果 None,则返回每个班级的分数。

#Multi-class classification (binary precision for each label)

#3 classes, 6 samples
y_true = [[1,0,0],[0,1,0],[0,0,1],[1,0,0],[0,1,0],[0,0,1]]
y_pred = [[1,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,0],[0,1,0]]

print('sklearn precision: ',precision_score(y_true, y_pred, average=None))
#If None, the scores for each class are returned.

#For class 0
m0 = tf.keras.metrics.Precision(class_id=0)
m0.update_state(y_true, y_pred)

#For class 1
m1 = tf.keras.metrics.Precision(class_id=1)
m1.update_state(y_true, y_pred)

#For class 2
m2 = tf.keras.metrics.Precision(class_id=2)
m2.update_state(y_true, y_pred)

mm = [m0.result().numpy(), m1.result().numpy(), m2.result().numpy()]

print('tf.keras precision:',mm)
sklearn precision:  [0.66666667 0.         0.        ]
tf.keras precision: [0.6666667, 0.0, 0.0]

场景 4:多类分类(单个二进制分数的平均值)

计算出每个类别的个体精度后,您可能需要取平均分(或加权平均值)。在sklearn,通过设置参数对各个分数进行简单平均average to macro. In tf.keras您可以通过计算上述场景中计算的各个精度的平均值来获得相同的结果。

Sklearn 文档:计算每个标签的指标,并找到它们的未加权平均值。

#Multi-class classification (Average of individual binary scores)

#3 classes, 6 samples
y_true = [[1,0,0],[0,1,0],[0,0,1],[1,0,0],[0,1,0],[0,0,1]]
y_pred = [[1,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,0],[0,1,0]]

print('sklearn precision (Macro): ',precision_score(y_true, y_pred, average='macro'))
print('sklearn precision (Avg of None):' ,np.average(precision_score(y_true, y_pred, average=None)))

print(' ')

print('tf.keras precision:',np.average(mm)) #mm is list of individual precision scores
sklearn precision (Macro):  0.2222222222222222
sklearn precision (Avg of None):  0.2222222222222222
 
tf.keras precision: 0.22222222

NOTE:请记住,与sklearn,您有直接预测标签的模型,并且precision_score是一个独立的方法。因此,它可以直接对预测值和实际值的标签列表进行操作。然而,tf.keras.Precision()是必须应用于二进制或多类密集输出的度量。它将无法直接使用标签。您必须为每个样本提供一个 n 长度的数组,其中 n 是类/输出密集节点的数量。

希望这能澄清两者在各种情况下的不同之处。请在以下位置找到更多详细信息sklearn文档 https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.htmltf.keras 文档 https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision.


你的第二个问题——

根据 sklearn 文档,

zero_division - “warn”, 0 or 1, default=”warn”
#Sets the value to return when there is a zero division. If set to “warn”, #this acts as 0, but warnings are also raised.

这是一个异常处理标志。在计算分数的过程中,如果有一次遇到divide by zero,它会认为它等于零并发出警告。否则,如果显式设置为 1,则将其设置为 1。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

精度类型 的相关文章

随机推荐