Tensorflow 精度/召回率/F1 分数和混淆矩阵

2024-01-07

我想知道是否有一种方法可以实现 scikit learn 包中的不同分数函数,如下所示:

from sklearn.metrics import confusion_matrix
confusion_matrix(y_true, y_pred)

进入张量流模型以获得不同的分数。

with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
init = tf.initialize_all_variables()
sess.run(init)
for epoch in xrange(1):
        avg_cost = 0.
        total_batch = len(train_arrays) / batch_size
        for batch in range(total_batch):
                train_step.run(feed_dict = {x: train_arrays, y: train_labels})
                avg_cost += sess.run(cost, feed_dict={x: train_arrays, y: train_labels})/total_batch
        if epoch % display_step == 0:
                print "Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost)

print "Optimization Finished!"
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
# Calculate accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print "Accuracy:", batch, accuracy.eval({x: test_arrays, y: test_labels})

我是否必须再次运行会话才能获得预测?


您实际上并不需要 sklearn 来计算精度/召回率/f1 分数。您可以通过查看以下公式轻松地以 TF 式的方式表达它们:

现在如果你有你的actual and predicted值作为 0/1 的向量,您可以使用以下方法计算 TP、TN、FP、FNtf.count_nonzero https://www.tensorflow.org/api_docs/python/tf/count_nonzero:

TP = tf.count_nonzero(predicted * actual)
TN = tf.count_nonzero((predicted - 1) * (actual - 1))
FP = tf.count_nonzero(predicted * (actual - 1))
FN = tf.count_nonzero((predicted - 1) * actual)

现在您的指标很容易计算:

precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1 = 2 * precision * recall / (precision + recall)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow 精度/召回率/F1 分数和混淆矩阵 的相关文章

  • 通过 Scrapy 抓取 Google Analytics

    我一直在尝试使用 Scrapy 从 Google Analytics 获取一些数据 尽管我是一个完全的 Python 新手 但我已经取得了一些进展 我现在可以通过 Scrapy 登录 Google Analytics 但我需要发出 AJAX
  • SQLAlchemy 通过关联对象声明式多对多自连接

    我有一个用户表和一个朋友表 它将用户映射到其他用户 因为每个用户可以有很多朋友 这个关系显然是对称的 如果用户A是用户B的朋友 那么用户B也是用户A的朋友 我只存储这个关系一次 除了两个用户 ID 之外 Friends 表还有其他字段 因此
  • 将 Matplotlib 误差线放置在不位于条形中心的位置

    我正在 Matplotlib 中生成带有错误栏的堆积条形图 不幸的是 某些层相对较小且数据多样 因此多个层的错误条可能重叠 从而使它们难以或无法读取 Example 有没有办法设置每个误差条的位置 即沿 x 轴移动它 以便重叠的线显示在彼此
  • Django:按钮链接

    我是一名 Django 新手用户 尝试创建一个按钮 单击该按钮会链接到我网站中的另一个页面 我尝试了一些不同的例子 但似乎没有一个对我有用 举个例子 为什么这不起作用
  • 如何使用Conda下载python包并随后离线安装?

    我知道通过 pip 我可以使用以下命令下载 Python 包 但 pip install 破坏了我的内部包依赖关系 当我做 pip download
  • 从字符串中删除识别的日期

    作为输入 我有几个包含不同格式日期的字符串 例如 彼得在16 45 我的生日是1990年7月8日 On 7 月 11 日星期六我会回家 I use dateutil parser parse识别字符串中的日期 在下一步中 我想从字符串中删除
  • 是否可以忽略一行的pyright检查?

    我需要忽略一行的pyright 检查 有什么特别的评论吗 def create slog group SLogGroup data Optional dict None SLog insert one SLog group group da
  • SQLALchemy .query:类“Car”的未解析属性引用“query”

    我有一个这里已经提到的问题https youtrack jetbrains com issue PY 44557 https youtrack jetbrains com issue PY 44557 但我还没有找到解决方案 我使用 Pyt
  • 以编程方式停止Python脚本的执行? [复制]

    这个问题在这里已经有答案了 是否可以使用命令在任意行停止执行 python 脚本 Like some code quit quit at this point some more code that s not executed sys e
  • 如何加速Python中的N维区间树?

    考虑以下问题 给定一组n间隔和一组m浮点数 对于每个浮点数 确定包含该浮点数的区间子集 这个问题已经通过构建一个解决区间树 https en wikipedia org wiki Interval tree 或称为范围树或线段树 已经针对一
  • 绘制方程

    我正在尝试创建一个函数 它将绘制我告诉它的任何公式 import numpy as np import matplotlib pyplot as plt def graph formula x range x np array x rang
  • 如何在ipywidget按钮中显示全文?

    我正在创建一个ipywidget带有一些文本的按钮 但按钮中未显示全文 我使用的代码如下 import ipywidgets as widgets from IPython display import display button wid
  • 如何使用Python创建历史时间线

    So I ve seen a few answers on here that helped a bit but my dataset is larger than the ones that have been answered prev
  • 无法在 Python 3 中导入 cProfile

    我试图将 cProfile 模块导入 Python 3 3 0 但出现以下错误 Traceback most recent call last File
  • Jupyter Notebook 内核一直很忙

    我已经安装了 anaconda 并且 python 在 Spyder IPython 等中工作正常 但是我无法运行 python 笔记本 内核被创建 它也连接 但它始终显示黑圈忙碌符号 防火墙或防病毒软件没有问题 我尝试过禁用两者 我也无法
  • Fabric env.roledefs 未按预期运行

    On the 面料网站 http docs fabfile org en 1 10 usage execution html 给出这个例子 from fabric api import env env roledefs web hosts
  • Conda SafetyError:文件大小不正确

    使用创建 Conda 环境时conda create n env name python 3 6 我收到以下警告 Preparing transaction done Verifying transaction SafetyError Th
  • Python:如何将列表列表的元素转换为无向图?

    我有一个程序 可以检索 PubMed 出版物列表 并希望构建一个共同作者图 这意味着对于每篇文章 我想将每个作者 如果尚未存在 添加为顶点 并添加无向边 或增加每个合著者之间的权重 我设法编写了第一个程序 该程序检索每个出版物的作者列表 并
  • 如何解释tf.map_fn的结果?

    看代码 import tensorflow as tf import numpy as np elems tf ones 1 2 3 dtype tf int64 alternates tf map fn lambda x x x x el
  • NotImplementedError:无法将符号张量 (lstm_2/strided_slice:0) 转换为 numpy 数组。时间

    张量流版本 2 3 1 numpy 版本 1 20 在代码下面 define model model Sequential model add LSTM 50 activation relu input shape n steps n fe

随机推荐

  • MVC 查看可为空的日期字段格式

    我试图在视图中显示以下内容 但出现问题 td item CreatedByDt ToString MM dd yyyy td 关于如何处理视图中可为空的日期字段的任何想法 顺便说一句 我正在使用 Razor 我收到以下错误 方法 ToStr
  • 减少 solr 结果输出中类似的顶部结果

    我在 solr 中进行了一次搜索 返回了大约 1500 个文档 这些文档基本上都是产品 例如 我的数据集中有一堆女鞋 我的数据集有各种各样的女鞋 但也有一些非常相似的结果 例如 11 号女式耐克运动鞋 10 号女式耐克运动鞋等 现在 当我搜
  • PDFBOX - 使用 easytable 的所有页面中的页眉

    我正在使用 pdfbox 和 easytablehttps github com vandeseer easytable https github com vandeseer easytable用于创建效果很好的动态页面 但我确实希望在所有
  • Matplotlib 多条动画多行

    我一直在研究如何为飞行路径制作多条线的动画 我读取多个 GPS 文件的对象是时间同步它们 它们相对于时间为每条路径设置动画 我找到了如何在动画函数中使用附加来为一行添加动画 现在我需要添加第二个和第三个 以便导入尽可能多的文件 我知道问题出
  • 无法在有关 iron lib 的 fn 项目中捕获动态环境

    我使用c c 驱动的cassandra来查询 然后返回数据 因此 cass LinkedList 和cass it Vec 都可以显示查询的结果 但是 我想使用json格式将结果显示到web上 所以我选择使用vec重新组装数据 然而 有一个
  • 使用并行 NetCDF 保存分布式 3D 复杂数组

    我有一个用 Fortran 编写的基于 MPI 的程序 它在每个节点 2D 时间序列的部分 生成复杂数据的 3D 数组 我想使用并行 I O 将这些数组写入单个文件 该文件可以相对轻松地在 python 中打开以进行进一步分析 可视化 理想
  • 如果我从服务层公开 IQueryable,那么当我需要从多个服务获取信息时,数据库调用不是会减少吗?

    如果我从服务层公开 IQueryable 那么当我需要从多个服务获取信息时 数据库调用不是会减少吗 例如 我想在一个页面上显示 2 个单独的列表 Posts and Users 我有两个单独的服务提供这些服务的列表 如果两者都提供 IQue
  • 在 Emacs 中编译程序?

    在 emacs 中编译程序的最佳方法是什么 我目前正在打开一个单独的缓冲区C x 3并在其中运行 eshell 使用M x eshell然后直接调用 make 或 clang 大多数时候我确实设置了 Makefile 使用运行编译过程有什么
  • Apache Tiles 替代品

    我正在编写一个 Spring MVC 应用程序 并寻找一种在视图中进行布局的方法 我看到的唯一选择是 Apache Tiles 我以前使用过它并且知道维护其配置是多么痛苦 有什么好的选择吗 我在看SiteMesh http www site
  • Gitlab CI 如何使用规则语法忽略目录?

    我能够使用以下语法忽略目录 文件更改 build script npm run build except changes md src ts 有了这个配置build作业将运行 除非 git 更改仅包含 md扩展文件或 ts文件在src目录
  • requiredFieldValidator 不适用于下拉列表

    我有一个Dropdownlist在我的网页中如下
  • 使用 python range 对象索引 numpy 数组

    我以前见过它一两次 但我似乎找不到任何关于它的官方文档 Using pythonrange对象作为 numpy 中的索引 import numpy as np a np arange 9 reshape 3 3 a range 3 rang
  • 从 Woocommerce 3.4+ 中“我的帐户编辑地址”字段中删除(可选)文本

    我正在尝试删除 span class optional optional span 从 WooCommerce 我的帐户编辑地址页面 还有其他方法可以做到这一点吗 optional display none 我认为最好将其从表单中的 DOM
  • QTimer 随着每次启动/停止而变得更快

    我正在使用一个QTimer平滑地改变标签的大小 当我将鼠标悬停在按钮上时 它应该慢慢增大 当鼠标离开按钮时 它应该慢慢折叠 减小它的大小直到消失 我的表单类中有两个计时器 QTimer oTimer cTimer oTimer for ex
  • .NET BCL 中的跟踪与调试

    看来 System Diagnostics Debug https msdn microsoft com en us library system diagnostics debug v vs 110 aspx and System Dia
  • Docker 中的 Mariadb:MariaDB Connector/Python 需要 MariaDB Connector/C >= 3.2.4,发现版本 3.1.16

    我尝试以下 Dockerfile syntax docker dockerfile 1 FROM python 3 11 slim bullseye EXPOSE 80 WORKDIR app RUN apt get update apt
  • 热键、快捷键和加速键有什么区别?

    他们有什么区别呢 在Qt中 如果我有QPushButton的热键 我可以通过 Alt 来实现 但如果是qaction 我可以按 仅有的 In Windows an accelerator key is application global
  • 使用python调整两个字符串之间的1个空格

    我有两个字符串 gt gt gt a abcd gt gt gt b xyz gt gt gt c a b gt gt gt c abcdxyz 我怎样才能得到abcd xyz结果添加时a and b 只需在两个字符串之间添加一个空格即可
  • 什么样的工作受益于 OpenCL

    首先 我很清楚 OpenCL 并没有神奇地让一切变得更快 我很清楚 OpenCL 有局限性 现在回答我的问题 我习惯使用编程进行不同的科学计算 我处理的一些事情在计算的复杂性和数量方面非常激烈 所以我想知道 也许我可以使用 OpenCL 来
  • Tensorflow 精度/召回率/F1 分数和混淆矩阵

    我想知道是否有一种方法可以实现 scikit learn 包中的不同分数函数 如下所示 from sklearn metrics import confusion matrix confusion matrix y true y pred