无法在 Keras 中使用 VGG19 预测单个图像的标签

2023-12-01

我正在根据[本教程](使用迁移学习方法在 Keras 中使用经过训练的 VGG19 模型https://towardsdatascience.com/keras-transfer-learning-for-beginners-6c9b8b7143e)。它展示了如何训练模型,但不展示如何为预测准备测试图像。

评论区说:

获取图像,使用相同的方法对图像进行预处理preprocess_image函数,并调用model.predict(image)。这将为您提供该图像上模型的预测。使用argmax(prediction),您可以找到图像所属的类。

我找不到名为的函数preprocess_image在代码中使用。我做了一些搜索并考虑使用提出的方法本教程.

但这给出了一个错误:

decode_predictions expects a batch of predictions (i.e. a 2D array of shape (samples, 1000)). Found array with shape: (1, 12)

我的数据集有 12 个类别。这是训练模型的完整代码以及我如何得到这个错误:

import pandas as pd
import numpy as np
import os
import keras
import matplotlib.pyplot as plt

from keras.layers import Dense, GlobalAveragePooling2D
from keras.applications.vgg19 import VGG19
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.optimizers import Adam

base_model = VGG19(weights='imagenet', include_top=False)

x=base_model.output                                                          
x=GlobalAveragePooling2D()(x)                                                
x=Dense(1024,activation='relu')(x)                                           
x=Dense(1024,activation='relu')(x)                                           
x=Dense(512,activation='relu')(x)        

preds=Dense(12,activation='softmax')(x)                                      
model=Model(inputs=base_model.input,outputs=preds)                           

# view the layer architecture
# for i,layer in enumerate(model.layers):
#   print(i,layer.name)

for layer in model.layers:
    layer.trainable=False

for layer in model.layers[:20]:
    layer.trainable=False

for layer in model.layers[20:]:
    layer.trainable=True

train_datagen=ImageDataGenerator(preprocessing_function=preprocess_input)

train_generator=train_datagen.flow_from_directory('dataset',
                    target_size=(96,96), # 224, 224
                    color_mode='rgb',
                    batch_size=64,
                    class_mode='categorical',
                    shuffle=True)

model.compile(optimizer='Adam',loss='categorical_crossentropy',metrics=['accuracy'])

step_size_train=train_generator.n//train_generator.batch_size

model.fit_generator(generator=train_generator,
    steps_per_epoch=step_size_train,
    epochs=5)

# model.predict(new_image)

IPython:

In [3]: import classify_tl                                                                                                                                                   
Found 4750 images belonging to 12 classes.
Epoch 1/5
74/74 [==============================] - 583s 8s/step - loss: 2.0113 - acc: 0.4557
Epoch 2/5
74/74 [==============================] - 576s 8s/step - loss: 0.8222 - acc: 0.7170
Epoch 3/5
74/74 [==============================] - 563s 8s/step - loss: 0.5875 - acc: 0.7929
Epoch 4/5
74/74 [==============================] - 585s 8s/step - loss: 0.3897 - acc: 0.8627
Epoch 5/5
74/74 [==============================] - 610s 8s/step - loss: 0.2689 - acc: 0.9071

In [6]: model = classify_tl.model                                                                                                                                            

In [7]: print(model)                                                                                                                                                         
<keras.engine.training.Model object at 0x7fb3ad988518>

In [8]: from keras.preprocessing.image import load_img                                                                                                                       

In [9]: image = load_img('examples/0021e90e4.png', target_size=(96,96))                                                                                                      

In [10]: from keras.preprocessing.image import img_to_array                                                                                                                  

In [11]: image = img_to_array(image)                                                                                                                                         

In [12]: image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))                                                                                          

In [13]: from keras.applications.vgg19 import preprocess_input                                                                                                               

In [14]: image = preprocess_input(image)                                                                                                                                     

In [15]: yhat = model.predict(image)                                                                                                                                         

In [16]: print(yhat)                                                                                                                                                         
[[1.3975363e-06 3.1069856e-05 9.9680350e-05 1.7175063e-03 6.2767825e-08
  2.6133494e-03 7.2859187e-08 6.0187017e-07 2.0794137e-06 1.3714411e-03
  9.9416250e-01 2.6067207e-07]]

In [17]: from keras.applications.vgg19 import decode_predictions                                                                                                             

In [18]: label = decode_predictions(yhat) 

IPython 提示符中的最后一行导致以下错误:

ValueError: `decode_predictions` expects a batch of predictions (i.e. a 2D array of shape (samples, 1000)). Found array with shape: (1, 12)

我应该如何正确地输入我的测试图像并获得预测?


decode_predictions用于根据具有 1000 个类的 ImageNet 数据集中的类标签来解码模型的预测。然而,您的微调模型只有 12 个类。因此,使用没有意义decode_predictions这里。当然,您必须知道这 12 个类别的标签是什么。因此,只需取预测中最大得分的索引并找到其标签:

# create a list containing the class labels
class_labels = ['class1', 'class2', 'class3', ...., 'class12']

# find the index of the class with maximum score
pred = np.argmax(class_labels, axis=-1)

# print the label of the class with maximum score
print(class_labels[pred[0]])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

无法在 Keras 中使用 VGG19 预测单个图像的标签 的相关文章

  • 计算另一个字符串中多个字符串的出现次数

    在 Python 2 7 中 给定以下字符串 Spot是一只棕色的狗 斑点有棕色的头发 斑点的头发是棕色的 查找字符串中 Spot brown 和 hair 总数的最佳方法是什么 在示例中 它将返回 8 我正在寻找类似的东西string c
  • 在 Django Admin 中调整字段大小

    在管理上添加或编辑条目时 Django 倾向于填充水平空间 但在某些情况下 当编辑 8 个字符宽的日期字段或 6 或 8 个字符的 CharField 时 这确实是一种空间浪费 字符宽 然后编辑框最多可容纳 15 或 20 个字符 我如何告
  • 为什么 web2py 在启动时崩溃?

    我正在尝试让 web2py 在 Ubuntu 机器上运行 所有文档似乎都表明要在 nix 系统上运行它 您需要下载源代码并执行以下操作 蟒蛇 web2py py 我抓住了source http www web2py com examples
  • MongoEngine 查询具有以列表中指定的前缀开头的属性的对象的列表

    我需要在 Mongo 数据库中查询具有以列表中任何前缀开头的特定属性的元素 现在我有一段这样的代码 query mymodel terms term in query terms 并且这会匹配在列表 term 上有一个项目的对象 该列表中的
  • 打印数字时添加千位分隔符[重复]

    这个问题在这里已经有答案了 我真的不知道这个问题的 名称 所以它可能是一个不正确的标题 但问题很简单 如果我有一个数字 例如 number 23543 second 68471243 我想要它使print 像这样 23 54368 471
  • 矩形函数的数值傅里叶变换

    本文的目的是通过一个众所周知的分析傅里叶变换示例来正确理解 Python 或 Matlab 上的数值傅里叶变换 为此 我选择矩形函数 这里报告了它的解析表达式及其傅立叶变换https en wikipedia org wiki Rectan
  • 打印包含字符串和其他 2 个变量的变量

    var a 8 var b 3 var c hello my name is var a and var b bye print var c 当我运行程序时 var c 会像这样打印出来 hello my name is 8 and 3 b
  • 使用 Python Oauthlib 通过服务帐户验证 Google API

    我不想使用适用于 Python 的 Google API 客户端库 但仍想使用 Python 访问 Google APIOauthlib https github com idan oauthlib 创建服务帐户后谷歌开发者控制台 http
  • 导入错误:没有名为flask.ext.login的模块

    我的flask login 模块有问题 我已经成功安装了flask login模块 另外 从命令提示符我可以轻松运行此脚本 不会出现错误 Python 2 7 r27 82525 Jul 4 2010 07 43 08 MSC v 1500
  • 嵌套作用域和 Lambda

    def funct x 4 action lambda n x n return action x funct print x 2 prints 16 我不太明白为什么2会自动分配给n n是返回的匿名函数的参数funct 完全等价的定义fu
  • Python - 如何确定解析的 XML 元素的层次结构级别?

    我正在尝试使用 Python 解析 XML 文件中具有特定标记的元素并生成输出 excel 文档 该文档将包含元素并保留其层次结构 我的问题是我无法弄清楚每个元素 解析器在其上迭代 的嵌套深度 XML 示例摘录 3 个元素 它们可以任意嵌套
  • 如何使用 Python 3 检查目录是否包含文件

    我到处寻找这个答案但找不到 我正在尝试编写一个脚本来搜索特定的子文件夹 然后检查它是否包含任何文件 如果包含 则写出该文件夹的路径 我已经弄清楚了子文件夹搜索部分 但检查文件却难倒了我 我发现了有关如何检查文件夹是否为空的多个建议 并且我尝
  • 找到一个数字所属的一组范围

    我有一个 200k 行的数字范围列表 例如开始位置 停止位置 该列表包括除了非重叠的重叠之外的所有类型的重叠 列表看起来像这样 3 5 10 30 15 25 5 15 25 35 我需要找到给定数字所属的范围 并对 100k 个数字重复该
  • Protobuf 如何编码 oneof 消息结构

    对于这个 python 程序 在编码时运行 protobuf 编码会给出以下输出 0a 10 08 7f8a 0104 08 02 10 0392 0104 08 02 10 03 18 01 我不明白的是为什么8a后面有一个01 为什么9
  • 在 Google App Engine 中,如何避免创建具有相同属性的重复实体?

    我正在尝试添加一个事务 以避免创建具有相同属性的两个实体 在我的应用程序中 每次看到新的 Google 用户登录时 我都会创建一个新的播放器 当新的 Google 用户在几毫秒内进行多个 json 调用时 我当前的实现偶尔会创建重复的播放器
  • 带有 LSTM 的 GridSearchCV/RandomizedSearchCV

    我一直在尝试通过 RandomizedSearchCV 调整 LSTM 的超参数 我的代码如下 X train X train reshape X train shape 0 1 X train shape 1 X test X test
  • python 中的“槽包装器”是什么?

    object dict 和其他地方的隐藏方法设置为这样的
  • 如何使用 PrimaryKeyRelatedField 更新多对多关系上的类别

    Django Rest 框架有一个主键相关字段 http www django rest framework org api guide relations primarykeyrelatedfield其中列出了我的 IDmany to m
  • pytest找不到模块[重复]

    这个问题在这里已经有答案了 我正在关注pytest 良好实践 https docs pytest org en latest explanation goodpractices html test discovery或者至少我认为我是 但是
  • NLTK:查找单词大小为 2k 的上下文

    我有一个语料库 我有一个词 对于语料库中该单词的每次出现 我想获取一个包含该单词之前的 k 个单词和该单词之后的 k 个单词的列表 我在算法上做得很好 见下文 但我想知道 NLTK 是否提供了一些我错过的功能来满足我的需求 def size

随机推荐

  • MKOverlay 视图模糊

    我正在尝试使用 MKOverlayView 添加 png 图像作为自定义地图 我几乎就在那里 我能够将图像排列在正确的位置 并且我知道 MKOverlayView 子类中的 drawMapRect 方法正在被定期调用 我似乎无法正确渲染图像
  • cmd:将 wmic 输出保存到变量

    我正在尝试将文件的时间戳放入批处理文件中的变量中 我的批处理文件 imagetime bat 包含以下内容 set targetfile 1 set targetfile targetfile echo targetfile for f u
  • Windows Phone 8 LongListSelector 内的图像内存泄漏

    我有一个 LongListSelector 其中包含一个图像控件 该控件从网络加载大量图像 这在一段时间内工作正常 但在加载一些图像后 我出现内存不足异常 我读到其他人也有同样的问题 涉及大量图像内存不足 但仍然没有找到解决方案 我读到它与
  • Firebase 部署错误:构建失败:npm ERR!代码EUSAGE

    我目前正在使用 Firebase 托管 但突然遇到部署错误 我尝试了所有可能的解决方案 但部署错误多次出现 所以现在需要社区的帮助 请帮我解决这个问题 从这条线上部署过程失败 i functions updating Node js 16
  • 主源文件中的结构值未更新

    我的项目的一部分 一些源文件是button key h button key h lcd h mani c etc 在按钮 key H 中使用了一个结构并声明为 struct menu uint8 t Hour uint8 t Minute
  • 拦截列表总体以在反序列化中分配值

    我有一个递归类 树层次结构 它派生自一个列表 该列表具有子级及其自身 通过 JSON NET 中的反序列化填充 TLDR 版本是 我想在该类存在的每个级别上从父级填充子级中的变量 而不使用 JSON NET 中的 ref 变量 存储到 co
  • 如何从R中UNC指定的目录中读取文件?

    是否可以从 UNC 指定的目录中读取文件R 我想在不使用基本安装之外的任何软件包的情况下完成此操作 UNC 名称工作正常 您只需正确转义它们即可 这对我有用 read csv COMPUTER Directory file txt
  • 我如何在 QMake 中包含 python.h

    INCLUDEPATH L usr include python2 7 LIBS usr local lib python2 7 QMAKE CXXFLAGS usr local lib python2 7 error cannot fin
  • C++程序与MySQL数据库通信

    有谁知道 C 程序直接与 MySQL 数据库通信的简单方法吗 我查看了 MySQL 发现它非常令人困惑 如果有人知道一个非常简单的方法 请告诉我 Thanks 附 我正在 Windows 机器上进行开发 PHP 和 MySQL Web 应用
  • LINQ Intersect 不返回项目

    我已经为我的自定义类实现了一个比较类 以便我可以使用Intersect在两个列表中 StudentList1 and StudentList2 但是 当我运行以下代码时 我没有得到任何结果 Student class CompareStud
  • awk 要求合并两个文件

    我通过 AWK 命令使用 Same Key 组合了两个不同的文件 如果与 File1 和 File2 相比没有关键匹配 则只需 把 t t t 代替 我有以下 AWK 命令 awk F t key 1 NR 1 header key key
  • 为什么不能将 Dictionary> 转换为 Dictionary>?

    我想知道为什么我不能直接进行强制转换 我有一个模糊的想法 这可能与协 逆变的东西有关 我是否被迫按顺序将第一个字典的元素复制到新的字典中得到我想要的类型 你不能这样做 因为它们不是同一类型 考虑 var x new Dictionary
  • 不使用模式名称访问表

    我是 DB2 新手 如果不使用架构名称 我无法从表中获取数据 如果我使用带有表名的模式名称 我就可以获取数据 Example SELECT FROM TABLE NAME 它给了我错误 同时 SELECT FROM SCHEMA NAME
  • 如何在Python OpenCV中获取轮廓的x,y位置

    我试图从下图中获取轮廓的 x 和 y 位置 但我搞砸了 图片 我只需要找到轮廓的 x 和 y 位置或轮廓的中心 当我从 GIMP 手动查找它们的位置时 结果将类似于以下内容 290 210 982 190 570 478 我相信可以用 cv
  • 从 C# 调用非托管 C++ 库 (dll) 会产生访问冲突错误 (0xc0000005)

    抱歉问了这么长的问题 我只是想包括我目前所知道的有关该问题的所有信息 我正在使用 Visual Studio 2008 用 C 创建一个 Windows 窗体程序 该程序调用用 C 编写的库 C DLL 分析由多个样本组成的测量数据 样本通
  • 使用php脚本将多个doc或rtf文件合并为一个doc或rtf文件

    我想将多个 doc 或 rtf 文件合并到一个文件中 该文件应该与多个文件的格式相同 我的意思是 如果用户从列表框中选择多个 rtf 模板文件并单击网页上的按钮 则输出应该是组合多个 rtf 模板文件的单个 rtf 文件 我应该使用 php
  • 使用 TypeScript 将箭头函数分配给泛型函数类型

    我已经对类似问题进行了一些挖掘 但找不到有效的解决方案 我有一些类型的通用函数 但我似乎无法正确实现它们 简而言之 我有这个 Takes three values of the same type and collapses them in
  • Visual Studio 2010 无法加载导入了 元素的项目

    我们有一个大型 约 800 个单独的项目 系统 我们正在将其从旧的构建系统迁移到 Visual Studio 2010 在过去的几周里 我们为每个项目手动创建了 Visual Studio 项目文件 vcxproj 格式 我们可以仅使用 M
  • UnreachableBrowserException:无法启动新会话。可能的原因是 Selenium Grid 远程服务器的地址无效

    打开新驱动程序窗口时出错 org openqa selenium remote UnreachableBrowserException Could not start a new session Possible causes are in
  • 无法在 Keras 中使用 VGG19 预测单个图像的标签

    我正在根据 本教程 使用迁移学习方法在 Keras 中使用经过训练的 VGG19 模型https towardsdatascience com keras transfer learning for beginners 6c9b8b7143