该函数返回 2 个操作:
auc, update_op = tf.metrics.auc(...)
如果你跑sess.run(auc)
您将返回当前的 auc 值。这是您要报告的值,例如,print sess.run([auc, cost], feed_dict={...})
.
AUC 指标可能需要通过多次调用来计算sess.run
。例如,当您计算 AUC 的数据集不适合内存时。那就是update_op
进来。你需要每次调用它来累积计算所需的值auc
.
因此,在测试集评估期间,您可能会遇到以下情况:
for i in range(num_batches):
sess.run([accuracy, cost, update_op], feed_dict={...})
print("Final (accumulated) AUC value):", sess.run(auc))
当您想要重置累积值时(例如,在重新评估测试集之前),您应该重新初始化局部变量。这tf.metrics
包明智地将其累加器变量添加到局部变量集合中,默认情况下不包括可训练变量,例如权重。
sess.run(tf.local_variables_initializer()) # Resets AUC accumulator variables
https://www.tensorflow.org/api_docs/python/tf/metrics/auc https://www.tensorflow.org/api_docs/python/tf/metrics/auc