Tensorflow:我的准确性出现问题

2023-12-05

我只是运行一个简单的代码,并希望在训练后获得准确性。我加载了保存的模型,但是当我想要获得准确性时,却出现了问题。为什么?

# coding=utf-8
from  color_1 import read_and_decode, get_batch, get_test_batch
import AlexNet
import cv2
import os
import time
import numpy as np
import tensorflow as tf
import AlexNet_train
import math

batch_size=128
num_examples = 1000
crop_size=56

def evaluate(test_x, test_y):
    image_holder = tf.placeholder(tf.float32, [batch_size, 56, 56, 3], name='x-input')
    label_holder = tf.placeholder(tf.int32, [batch_size], name='y-input')

    y = AlexNet.inference(image_holder,evaluate,None)

    correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(label_holder,1))
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        sess.run(init_op)
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        ckpt=tf.train.get_checkpoint_state(AlexNet_train.MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            saver.restore(sess, os.path.join(AlexNet_train.MODEL_SAVE_PATH, ckpt_name))
            print('Loading success, global_step is %s' % global_step)
            step=0

            image_batch, label_batch = sess.run([test_x, test_y])
            accuracy_score=sess.run(accuracy,feed_dict={image_holder: image_batch,
                                                              label_holder: label_batch})
            print("After %s training step(s),validation "
                  "precision=%g" % (global_step, accuracy_score))
        coord.request_stop()  
        coord.join(threads)

def main(argv=None):
    test_image, test_label = read_and_decode('val.tfrecords')

    test_images, test_labels = get_test_batch(test_image, test_label, batch_size, crop_size)

    evaluate(test_images, test_labels)


if __name__=='__main__':
    tf.app.run()

这是错误,它说我的代码中的这一行是错误的:“ Correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(label_holder,1))”

Traceback (most recent call last):
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 80, in <module>
    tf.app.run()
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 76, in main
    evaluate(test_images, test_labels)
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 45, in evaluate
    label_holder: label_batch})
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 965, in _run
    feed_dict_string, options, run_metadata)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run
    target_list, options, run_metadata)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected dimension in the range [-1, 1), but got 1
     [[Node: ArgMax_1 = ArgMax[T=DT_INT32, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_y-input_0, ArgMax_1/dimension)]]

Caused by op u'ArgMax_1', defined at:
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 80, in <module>
    tf.app.run()
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 76, in main
    evaluate(test_images, test_labels)
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 22, in evaluate
    correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(label_holder,1))
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 263, in argmax
    return gen_math_ops.arg_max(input, axis, name)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 168, in arg_max
    name=name)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2395, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1264, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Expected dimension in the range [-1, 1), but got 1
     [[Node: ArgMax_1 = ArgMax[T=DT_INT32, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_y-input_0, ArgMax_1/dimension)]]

怎么解决呢?


参与这个答案与这里的问题相关:

tf.argmax的定义 states:

轴:张量。必须是以下类型之一:int32、int64。 int32,0 。描述输入的轴 减少交叉的张量。

那么看来,唯一的办法就是逃跑argmax在张量的最后一个轴上是通过给它axis=-1,因为函数定义中的“严格小于”符号。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow:我的准确性出现问题 的相关文章

  • 如何在刻度标签和轴之间添加空间

    我已成功增加刻度标签的字体 但现在它们距离轴太近了 我想在刻度标签和轴之间添加一点呼吸空间 如果您不想全局更改间距 通过编辑 rcParams 并且想要更简洁的方法 请尝试以下操作 ax tick params axis both whic
  • Pycharm Python 控制台不打印输出

    我有一个从 Pycharm python 控制台调用的函数 但没有显示输出 In 2 def problem1 6 for i in range 1 101 2 print i end In 3 problem1 6 In 4 另一方面 像
  • 导入错误:没有名为 _ssl 的模块

    带 Python 2 7 的 Ubuntu Maverick 我不知道如何解决以下导入错误 gt gt gt import ssl Traceback most recent call last File
  • 如何打印没有类型的defaultdict变量?

    在下面的代码中 from collections import defaultdict confusion proba dict defaultdict float for i in xrange 10 confusion proba di
  • 如何使用包含代码的“asyncio.sleep()”进行单元测试?

    我在编写 asyncio sleep 包含的单元测试时遇到问题 我要等待实际的睡眠时间吗 I used freezegun到嘲笑时间 当我尝试使用普通可调用对象运行测试时 这个库非常有用 但我找不到运行包含 asyncio sleep 的测
  • 如何等到 Excel 计算公式后再继续 win32com

    我有一个 win32com Python 脚本 它将多个 Excel 文件合并到电子表格中并将其另存为 PDF 现在的工作原理是输出几乎都是 NAME 因为文件是在计算 Excel 文件内容之前输出的 这可能需要一分钟 如何强制工作簿计算值
  • __del__ 真的是析构函数吗?

    我主要用 C 做事情 其中 析构函数方法实际上是为了销毁所获取的资源 最近我开始使用python 这真的很有趣而且很棒 我开始了解到它有像java一样的GC 因此 没有过分强调对象所有权 构造和销毁 据我所知 init 方法对我来说在 py
  • 在循环中每次迭代开始时将变量重新分配给原始值(在循环之前定义)

    在Python中 你使用 在每次迭代开始时将变量重新分配给原始值 在循环之前定义 时 也就是说 original 1D o o o for i in range 0 3 new original 1D revert back to orig
  • 使用 Pycharm 在 Windows 下启动应用程序时出现 UnicodeDecodeError

    问题是当我尝试启动应用程序 app py 时 我收到以下错误 UnicodeDecodeError utf 8 编解码器无法解码位置 5 中的字节 0xb3 起始字节无效 整个文件app py coding utf 8 from flask
  • Python 中的二进制缓冲区

    在Python中你可以使用StringIO https docs python org library struct html用于字符数据的类似文件的缓冲区 内存映射文件 https docs python org library mmap
  • feedparser 在脚本运行期间失败,但无法在交互式 python 控制台中重现

    当我运行 eclipse 或在 iPython 中运行脚本时 它失败了 ascii codec can t decode byte 0xe2 in position 32 ordinal not in range 128 我不知道为什么 但
  • 当玩家触摸屏幕一侧时,如何让 pygame 发出警告?

    我使用 pygame 创建了一个游戏 当玩家触摸屏幕一侧时 我想让 pygame 给出类似 你不能触摸屏幕两侧 的错误 我尝试在互联网上搜索 但没有找到任何好的结果 我想过在屏幕外添加一个方块 当玩家触摸该方块时 它会发出警告 但这花了很长
  • HTTPS 代理不适用于 Python 的 requests 模块

    我对 Python 还很陌生 我一直在使用他们的 requests 模块作为 PHP 的 cURL 库的替代品 我的代码如下 import requests import json import os import urllib impor
  • 循环中断打破tqdm

    下面的简单代码使用tqdm https github com tqdm tqdm在循环迭代时显示进度条 import tqdm for f in tqdm tqdm range 100000000 if f gt 100000000 4 b
  • Python - 在窗口最小化或隐藏时使用 pywinauto 控制窗口

    我正在尝试做的事情 我正在尝试使用 pywinauto 在 python 中创建一个脚本 以在后台自动安装 notepad 隐藏或最小化 notepad 只是一个示例 因为我将编辑它以与其他软件一起使用 Problem 问题是我想在安装程序
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • 如何从没有结尾的管道中读取 python 中的 stdin

    当管道来自 打开 时 不知道正确的名称 我无法从 python 中的标准输入或管道读取数据 文件 我有作为例子管道测试 py import sys import time k 0 try for line in sys stdin k k
  • 用于运行可执行文件的python多线程进程

    我正在尝试将一个在 Windows 上运行可执行文件并管理文本输出文件的 python 脚本升级到使用多线程进程的版本 以便我可以利用多个核心 我有四个独立版本的可执行文件 每个线程都知道要访问它们 这部分工作正常 我遇到问题的地方是当它们
  • 在python中,如何仅搜索所选子字符串之前的一个单词

    给定文本文件中的长行列表 我只想返回紧邻其前面的子字符串 例如单词狗 描述狗的单词 例如 假设有这些行包含狗 hotdog big dog is dogged dog spy with my dog brown dogs 在这种情况下 期望
  • 如何使用google colab在jupyter笔记本中显示GIF?

    我正在使用 google colab 想嵌入一个 gif 有谁知道如何做到这一点 我正在使用下面的代码 它并没有在笔记本中为 gif 制作动画 我希望笔记本是交互式的 这样人们就可以看到代码的动画效果 而无需运行它 我发现很多方法在 Goo

随机推荐

  • 为什么colspan影响html表格边框

    所以我偶然发现了一些对我来说似乎很奇怪的东西 例如 以下代码 table tr td align center style border 3px solid black Title td tr tr td style border 2px
  • 如何检查jframe是否打开?

    我下面的代码创建一个新数组并将其发送到聊天 jFrame String info1 new String 3 username userid userid2 are variables info1 0 username4 info1 1 u
  • 如何更改JsRender模板标签?

    我用树枝 它使用这些标签 name 我想将 JsRender 包含在我的项目中 但 JsRender 也使用相同的标签 name 所以存在冲突并且没有任何作用 如何使用自定义标签更改默认的 JsRender 标签 类似于 Ruby UPD
  • PyQT 布局之间的导航

    下面是我的应用程序代码 它允许您在窗口之间切换 该菜单有两个编程选项 例如 详细报告 和 所有公司 现在加载布局后 我不知道如何将按钮放在这两个视图中 以允许您将视图从 详细报告 更改为 全部 公司 反之亦然 你能帮助我吗 class Ap
  • Stackdriver 日志记录中不会创建任何日志

    在我的谷歌应用程序脚本中 我有 Logger log test 我什至尝试过 console log test 但即使我将项目 id 设置为 Google Cloud 项目 id 也不会打印到 stackdriver 日志中 屏幕显示 为了
  • .NET 中将 int 转换为位数组

    如何将 int 转换为位数组 如果我例如有一个值为 3 的 int 我想要一个长度为 8 的数组 如下所示 0 0 0 0 0 0 1 1 这些数字中的每一个都位于数组中大小为 8 的单独槽中 Use the BitArray class
  • MySQL 多数据库设置

    我已经寻找了这个问题的答案 我似乎能找到的只是一些问题 询问是使用多个数据库还是在单个数据库中使用多个表更好 但这不是我的问题 问题 1 我想在当前数据库旁边设置一个新数据库 但不知道如何操作 我想授予用户对 DB2 的完全管理员访问权限
  • VBA:检测用户窗体的任何文本框中的更改

    有一个用户表单有很多文本框 我需要检测每个文本框的更改 因此 我为表单中的每个文本框编写了一个子例程 结果是一大段代码 由于每个文本框的代码都是相同的 我想优化它 那么是否可以只编写一个子例程来检测表单的任何文本框中的更改 实现这一目标的唯
  • 在32位系统上安装64位glib2进行交叉编译

    我正在尝试在 32 位 ubuntu 系统上交叉编译 64 位可执行文件 这一直有效 直到链接为止 由于缺少 64 位 glib2 libglib 2 0 a 它失败了 如果我在 64 位系统上执行此操作 我会使用getlibs将 32 位
  • OpenJPA 合并/持久非常慢

    我在 WebSphere Application Server 8 上使用 OpenJPA 2 2 0 和 MySQL 5 0 DB 我有一个要合并到数据库中的对象列表 就像是 for Object ob list Long start C
  • Redis:插入元素在开头还是结尾时,ZADD 是否比 O(logN) 更好?

    雷迪斯文档对于 ZADD 来说 操作是 O logN 然而 有谁知道 ZADD 是否比 O logN 当插入的元素位于排序顺序的开头或结尾时 例如 对于某些实现 这可能是 O 1 具体来说 redistutorial指出 排序集是通过双端口
  • 如何获取 Win32 中可用串行端口的列表?

    我有一些遗留代码 通过调用提供 PC 上可用 COM 端口的列表EnumPorts 函数 然后过滤以 COM 开头的端口名称 出于测试目的 如果我可以将此代码与类似的东西一起使用 那将非常有用com0com 它提供了成对的虚拟 COM 端口
  • Typescript 模块和 systemjs。从内联脚本实例化类

    我正在使用系统模块选项将 typescript 模块转换为 javascript 我正在浏览器中执行此操作 当初始化由 typescript 生成的模块类的代码也使用 systemjs system import 加载时 我可以使用此模块
  • td 中的多行

    Stores td 包含多行表 一个商店可以有多个 商店 行 参见示例 https jsfiddle net ak3wtkak 1 商店宽度和数量 th 第二个表中的多行列应相同 如何解决这个问题或者什么是替代方法 table border
  • 处理带有 Promise 的对象数组

    我正在尝试制作一个 Node Express 应用程序 在其中从不同的 url 获取数据 调用 node fetch 来提取某些页面的正文以及有关某些 url 端点的其他信息 然后我想渲染一个 html 表格来通过信息数组显示这些数据 我在
  • LINQ 中的更新查询包含 WHERE 子句中的所有列,而不仅仅是主键列

    我正在使用 Linq 更新表中的单个列 请使用下面的虚构表格 MyTable PKID ColumnToUpdate SomeRandomColumn var row from x in DataContext MyTable where
  • Android studio 在真实设备上运行应用程序后添加了不需要的权限

    在设备上运行应用程序后 应用程序需要清单文件中未提及的不需要的位置权限 当我从我的朋友 Android studio 运行相同的代码时 它运行正常 不需要额外的许可 清单文件
  • 更简洁的最大/最小版本,没有块

    我通常这样做 abc defg max a b a length lt gt b length 但这似乎需要大量额外的输入来比较两个对象上相同方法的结果 有没有更简洁的方法来做类似的事情 abc defg max length 哪个会在每个
  • 嵌套选择器 - 可能吗?

    假设我有一个div里面有一堆东西 div ul ul div class Component div div
  • Tensorflow:我的准确性出现问题

    我只是运行一个简单的代码 并希望在训练后获得准确性 我加载了保存的模型 但是当我想要获得准确性时 却出现了问题 为什么 coding utf 8 from color 1 import read and decode get batch g