如何获取 Tensorflow 2.0 中的其他指标(不仅仅是准确性)?

2023-12-30

我是 Tensorflow 领域的新手,正在研究 mnist 数据集分类的简单示例。我想知道除了准确性和损失(并可能显示它们)之外,如何获得其他指标(例如精度、召回率等)。这是我的代码:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import TensorBoard
import os 

#load mnist dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

#create and compile the model
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)), 
  tf.keras.layers.Dense(128, activation='relu'), 
  tf.keras.layers.Dropout(0.2), 
  tf.keras.layers.Dense(10, activation='softmax') 
])
model.summary()

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

#model checkpoint (only if there is an improvement)

checkpoint_path = "logs/weights-improvement-{epoch:02d}-{accuracy:.2f}.hdf5"

cp_callback = ModelCheckpoint(checkpoint_path, monitor='accuracy',save_best_only=True,verbose=1, mode='max')

#Tensorboard
NAME = "tensorboard_{}".format(int(time.time())) #name of the model with timestamp
tensorboard = TensorBoard(log_dir="logs/{}".format(NAME))

#train the model
model.fit(x_train, y_train, callbacks = [cp_callback, tensorboard], epochs=5)

#evaluate the model
model.evaluate(x_test,  y_test, verbose=2)

由于我只得到准确性和损失,我怎样才能得到其他指标? 提前谢谢您,如果这是一个简单的问题或者如果已经在某处得到回答,我很抱歉。


我添加另一个答案,因为这是在测试集上正确计算这些指标的最简洁的方法(截至 2020 年 3 月 22 日)。

您需要做的第一件事是创建一个自定义回调,在其中发送测试数据:

import tensorflow as tf
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import classification_report 

class MetricsCallback(Callback):
    def __init__(self, test_data, y_true):
        # Should be the label encoding of your classes
        self.y_true = y_true
        self.test_data = test_data
        
    def on_epoch_end(self, epoch, logs=None):
        # Here we get the probabilities
        y_pred = self.model.predict(self.test_data))
        # Here we get the actual classes
        y_pred = tf.argmax(y_pred,axis=1)
        # Actual dictionary
        report_dictionary = classification_report(self.y_true, y_pred, output_dict = True)
        # Only printing the report
        print(classification_report(self.y_true,y_pred,output_dict=False)              
           

在您的 main 中,加载数据集并添加回调:

metrics_callback = MetricsCallback(test_data = my_test_data, y_true = my_y_true)
...
...
#train the model
model.fit(x_train, y_train, callbacks = [cp_callback, metrics_callback,tensorboard], epochs=5)

         
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何获取 Tensorflow 2.0 中的其他指标(不仅仅是准确性)? 的相关文章

  • Python - 定义常量列表或字典的最佳/最简洁的方法

    第一次使用堆栈溢出 我很高兴来到这里 简介 我最近开始了 Python 编程世界的神奇冒险 我喜欢它 现在 在我从 C 语言的尴尬过渡中 一切都进展顺利 但我在创建与标头文件 h 同义的内容时遇到了麻烦 问题 我有中等大小的字典和列表 大约
  • 如何在 Django Admin 的“更改”页面中显示内嵌上传的图像?

    我正在尝试在中显示内联上传的图像 变更列表 页面在 Django 管理中 这是我的代码如下 models py from django db import models class Product models Model name mod
  • pandas python 根据一个或多个其他列的子集更新 A 列的子集

    Edit我修改了下面的部分描述 以澄清 功能 和 组 的含义 修复拼写错误 并包含我尝试过的其他代码 我的熊猫df有 450 万行和 23 列 下表显示了几行df2这是从生成的df 它显示了两组 eeskin and hduquant 和三
  • Ubuntu Python shebang 线不工作

    无法让 shebang 线在 Ubuntu 中为 python 脚本工作 我每次只收到命令未找到错误 test py usr bin env python print Ran which python usr bin python 在 sh
  • PyQt4 信号和槽

    我正在使用 PyQt4 编写我的第一个 Python 应用程序 我有一个 MainWindow 和一个 Dialog 类 它是 MainWindow 类的一部分 self loginDialog LoginDialog 我使用插槽和信号 这
  • 尽管 ioff() 和 matplotlib.use('Agg'),Pyplot“无法连接到 X 服务器 localhost:10.0”

    我有一段代码 它被不同的函数调用 为我执行一些计算 然后将输出绘制到文件中 鉴于整个脚本可能需要一段时间才能运行更大的数据集 并且由于我可能想在给定时间分析多个数据集 所以我开始它screen然后断开连接并关闭我的腻子会话 并在第二天再检查
  • Django REST Framework:无法使用视图名称解析超链接关系的 URL

    我已经广泛研究了这个相当常见的问题 但没有一个修复对我有用 我正在 REST 框架中构建 Django 项目 并希望使用超链接关系 用户可以拥有许多独立的汽车和路线 路线是位置的集合 这些是我的序列化器 class CarSerialize
  • 不使用 graphviz/web 可视化决策树

    由于某些限制 我无法使用 graphviz webgraphviz com 可视化决策树 工作网络与另一个世界是封闭的 问题 是否有一些替代实用程序或一些 Python 代码用于至少非常简单的可视化可能只是决策树的 ASCII 可视化 py
  • Python控制台默认十六进制显示

    我在 Python 控制台中做了很多工作 其中大部分都涉及地址 我更喜欢以十六进制形式查看地址 So if a 0xBADF00D 当我简单地输入Python gt a进入控制台查看其值 我更喜欢 python 回复0xBADF00D代替1
  • Pythoncom - 将相同的 COM 对象传递给多个线程

    你好 对于 COM 对象 我是一个完全的初学者 非常感谢任何帮助 我正在开发一个Python程序 该程序应该以客户端 服务器的方式读取传入的MS Word文档 即客户端发送一个请求 一个或多个MS Word文档 服务器使用pythoncom
  • 计算二维笛卡尔坐标中不规则形状的边界

    我正在寻找一种计算不规则形状边界的解决方案 Lats take a look at Square example 如果我有Minimum x and y and Maximum x and y like MaxX 5 MinX 1 MaxY
  • python 和 android 中通过 AES 算法加密和解密

    我有用于 AES 加密的 python 和 android 代码 当我在android中加密文本时 它在python上成功解密 但无法在android端解密 有人有想法吗 Python代码 import base64 import hash
  • 使用 Matplotlib、PyQt 和 Threading 进行实时绘图导致 python 崩溃

    我一直在努力研究我的 Python 应用程序 但找不到任何答案 我有 PyQT GUI 应用程序 它使用 Matplotlib 小部件 GUI 启动一个新线程来处理 mpl 小部件的绘图 恐怕我现在通过从另一个线程访问 matplotlib
  • Celery 设计帮助:如何防止并发执行任务

    我对 Celery AMQP 相当陌生 正在尝试提出一个任务 队列 工作人员设计来满足以下要求 我有多种类型的 每用户 任务 例如 TaskA TaskB TaskC 这些 每用户 任务中的每一个都为系统中的一个特定用户读取 写入数据 因此
  • 如何在 Python 中包含 PHP 脚本?

    我有一个 PHP 脚本 news generator php 当我包含它时 它会抓取一堆新闻项并打印它们 现在 我在我的网站 CGI 中使用 Python 当我使用 PHP 时 我在 新闻 页面上使用了这样的内容 为了简单起见 我删掉了这个
  • Python:如何使用生成器来避免 sql 内存问题

    我有以下方法来访问 mysql 数据库 并且查询在服务器中执行 我无权更改有关增加内存的任何内容 我对生成器很陌生 并开始阅读更多有关它的内容 并认为我可以将其转换为使用生成器 def getUNames self globalUserQu
  • 如何在 Flask 中获取 POSTed JSON?

    我正在尝试使用 Flask 构建一个简单的 API 现在我想在其中读取一些 POSTed JSON 我使用 Postman Chrome 扩展进行 POST 我 POST 的 JSON 很简单 text lalala 我尝试使用以下方法读取
  • Django:在单独的线程中使用相同的测试数据库

    我正在使用具有以下数据库设置的测试数据库运行 pytests DATABASES default ENGINE django db backends postgresql psycopg2 NAME postgres USER someth
  • 使用 Tweepy 获取推文时出错

    我有一个用于获取推文的 Python 脚本 在脚本中我使用该库 Tweepy 我使用有效的身份验证参数 运行此脚本后 一些推文存储在我的 MongoDB 中 有些则被 if 语句拒绝 但我仍然收到错误 requests packages u
  • 有效积累稀疏 scipy 矩阵的集合

    我有一个 O N NxN 的集合scipy sparse csr matrix 每个稀疏矩阵都有 N 个元素集 我想将所有这些矩阵加在一起以获得一个常规的 NxN numpy 数组 N 约为 1000 矩阵内非零元素的排列使得所得总和肯定不

随机推荐