Caffe 中的预测 - 异常:输入 blob 参数与网络输入不匹配

2024-04-22

我使用 Caffe 使用非常简单的 CNN 结构对非图像数据进行分类。我在尺寸为 n x 1 x 156 x 12 的 HDF5 数据上训练网络没有任何问题。但是,我在对新数据进行分类时遇到了困难。

如何在不进行任何预处理的情况下进行简单的前向传播?我的数据已经标准化并且具有正确的 Caffe 维度(它已经用于训练网络)。下面是我的代码和 CNN 结构。

EDIT:我已将问题隔离到 pycaffe.py 中的函数“_Net_forward”,并发现问题是由于 self.input 字典为空而出现的。谁能解释这是为什么吗?该集合应该等于来自新测试数据的集合:

if set(kwargs.keys()) != set(self.inputs):
            raise Exception('Input blob arguments do not match net inputs.')

我的代码发生了一些变化,因为我现在使用 IO 方法将数据转换为数据(见下文)。这样我就用正确的数据填充了 kwargs 变量。

即使是很小的提示也将不胜感激!

    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt

    # Make sure that caffe is on the python path:
    caffe_root = ''  # this file is expected to be run from {caffe_root}
    import sys
    sys.path.insert(0, caffe_root + 'python')

    import caffe

    import os
    import subprocess
    import h5py
    import shutil
    import tempfile

    import sklearn
    import sklearn.datasets
    import sklearn.linear_model
    import skimage.io



    def LoadFromHDF5(dataset='test_reduced.h5', path='Bjarke/hdf5_classification/data/'):

        f   = h5py.File(path + dataset, 'r')
        dat = f['data'][:]
        f.close()   

        return dat;

    def runModelPython():
        model_file = 'Bjarke/hdf5_classification/conv_v2_simple.prototxt'
        pretrained = 'Bjarke/hdf5_classification/data/train_iter_10000.caffemodel'
        test_data = LoadFromHDF5()

        net = caffe.Net(model_file, pretrained)
        caffe.set_mode_cpu()
        caffe.set_phase_test()  

        user = test_data[0,:,:,:] 
        datum = caffe.io.array_to_datum(user.astype(np.uint8))
        user_dat = caffe.io.datum_to_array(datum)
        user_dat = user_dat.astype(np.uint8)
        out = net.forward_all(data=np.asarray([user_dat]))

if __name__ == '__main__':
    runModelPython()

CNN 原型

name: "CDR-CNN"
layers {
  name: "data"
  type: HDF5_DATA
  top: "data"
  top: "label"
  hdf5_data_param {
    source: "Bjarke/hdf5_classification/data/train.txt"
    batch_size: 10
  }
  include: { phase: TRAIN }
}
layers {
  name: "data"
  type: HDF5_DATA
  top: "data"
  top: "label"
  hdf5_data_param {
    source: "Bjarke/hdf5_classification/data/test.txt"
    batch_size: 10
  }
  include: { phase: TEST }
}

layers {
  name: "feature_conv"
  type: CONVOLUTION
  bottom: "data"
  top: "feature_conv"
  blobs_lr: 1
  blobs_lr: 2
  convolution_param {
    num_output: 10
    kernel_w: 12
    kernel_h: 1
    stride_w: 1
    stride_h: 1
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
    }
  }
}
layers {
  name: "conv1"
  type: CONVOLUTION
  bottom: "feature_conv"
  top: "conv1"
  blobs_lr: 1
  blobs_lr: 2
  convolution_param {
    num_output: 14
    kernel_w: 1
    kernel_h: 4
    stride_w: 1
    stride_h: 1
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
    }
  }
}
layers {
  name: "pool1"
  type: POOLING
  bottom: "conv1"
  top: "pool1"
  pooling_param {
    pool: MAX
    kernel_w: 1
    kernel_h: 3
    stride_w: 1
    stride_h: 3
  }
}
layers {
  name: "conv2"
  type: CONVOLUTION
  bottom: "pool1"
  top: "conv2"
  blobs_lr: 1
  blobs_lr: 2
  convolution_param {
    num_output: 120
    kernel_w: 1
    kernel_h: 5
    stride_w: 1
    stride_h: 1
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
    }
  }
}
layers {
  name: "fc1"
  type: INNER_PRODUCT
  bottom: "conv2"
  top: "fc1"
  blobs_lr: 1
  blobs_lr: 2
  weight_decay: 1
  weight_decay: 0
  inner_product_param {
    num_output: 84
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layers {
  name: "accuracy"
  type: ACCURACY
  bottom: "fc1"
  bottom: "label"
  top: "accuracy"
  include: { phase: TEST }
}
layers {
  name: "loss"
  type: SOFTMAX_LOSS
  bottom: "fc1"
  bottom: "label"
  top: "loss"
}

Here is 我在 Caffe Google Groups 上得到的 Evan Shelhamer 的答案 https://groups.google.com/forum/#!topic/caffe-users/aojN_bmbg74:

self._inputs确实用于定义的手动或“部署”输入 通过 prototxt 中的输入字段。运行带有数据层的网络 通过pycaffe,只需调用net.forward()没有争论。不需要 更改训练网或测试网的定义。

例如,请参见代码单元 [10]Python LeNet 示例 http://nbviewer.ipython.org/github/BVLC/caffe/blob/tutorial/examples/01-learning-lenet.ipynb.

事实上我认为它更清楚使用 Caffe 进行即时识别教程 https://github.com/BVLC/caffe/blob/master/examples/00-classification.ipynb,单元格 6:

# Feed in the image (with some preprocessing) and classify with a forward pass.
net.blobs['data'].data[...] = transformer.preprocess('data', caffe.io.load_image(caffe_root + 'examples/images/cat.jpg'))
out = net.forward()
print("Predicted class is #{}.".format(out['prob'].argmax()))

换句话说,要使用 pycaffe 生成预测输出及其概率,训练模型后,您必须首先向数据层提供输入,然后执行前向传递net.forward().


或者,正如其他答案中指出的那样,您可以使用部署原型,该原型与您用来定义训练网络的原型类似,但删除了输入和输出层,并在开头添加以下内容(显然根据您的输入进行调整)方面):

name: "your_net"
input: "data"
input_dim: 1
input_dim: 1
input_dim: 1
input_dim: 250

这就是他们在CIFAR10教程 https://github.com/BVLC/caffe/blob/master/examples/cifar10/cifar10_quick.prototxt.

(pycaffe really ought to be better documented…)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Caffe 中的预测 - 异常:输入 blob 参数与网络输入不匹配 的相关文章

  • Pandas:将增量数字添加到一列的重复值的后缀,这些重复值按另一列的值分组并按索引排序

    我试图将下划线和增量数字添加到按索引排序的任何重复值以及由另一列定义的组内 例如 我希望 化学 列中的重复值具有下划线和增量数字 并按索引排序并按 循环 列分组 df pd DataFrame 1 1 1 1 1 1 2 2 2 2 2 2
  • Erlang:到 Python 实例的端口没有响应

    我正在尝试通过 Erlang 端口与外部 python 进程进行通信 首先 打开一个端口 然后通过 stdin 将消息发送到外部进程 我期待在进程的标准输出上得到相应的答复 我的尝试如下所示 open a port Port open po
  • 在函数内的 for 循环上使用 tqdm 来检查进度

    我正在使用 for 循环迭代目录树内的一大组文件 这样做时 我想通过控制台中的进度条来监视进度 因此 我决定使用 tqdm 来实现此目的 目前 我的代码如下所示 for dirPath subdirList fileList in tqdm
  • App Engine 上的 Django 与 webapp2 [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心 help reopen questi
  • 从内存地址创建python对象(使用gi.repository)

    有时我需要调用仅存在于 C 中的 gtk gobject 函数 但返回一个具有 python 包装器的对象 之前我使用过基于 ctypes 的解决方案 效果很好 现在我从 PyGtk import gtk 切换到 GObject intro
  • DynamodB:如何更新排序键?

    该表有两个键 filename 分区键 和eventTime 排序键 我要更新eventTime对于某些filename Tried put item and update item 发送相同的filename与新的eventTime但这些
  • 如何在 Pandas Python 中按 id 对行进行排名

    我有一个像这样的数据框 id points1 points2 1 44 53 1 76 34 1 63 66 2 23 34 2 44 56 我想要这样的输出 id points1 points2 points1 rank points2
  • Python:json_normalize pandas 系列给出 TypeError

    我在 pandas 系列中有数万行像这样的 json 片段df json IDs lotId 1 Id 123456 date 2009 04 17 bidsCount 2 IDs lotId 2 Id 123456 date 2009 0
  • 使用 Tkinter 打开网页

    因此 我的应用程序需要能够打开其中的单个网页 并且它必须来自互联网并且未保存 特别是我想使用 Tkinter GUI 工具包 因为它是我最熟悉的工具包 最重要的是 我希望能够在窗口中生成事件 例如单击鼠标 但无需实际使用鼠标 有什么好的方法
  • 会话数据库表清理

    该表是否需要清除或者由 Django 自动处理 Django 不提供自动清除功能 然而 有一个方便的命令可以帮助您手动完成此操作 Django 文档 清除会话存储 https docs djangoproject com en dev to
  • 在 Python 中从 Excel 复制 YEARFRAC() 函数

    因此 我使用 python 来自动执行一些必须在 Excel 中执行的重复任务 我需要做的计算之一需要使用yearfrac 这在Python中被复制了吗 I found this https lists oasis open org arc
  • 我可以用关闭的文件对象做什么?

    当您打开文件时 它存储在一个打开的文件对象中 该对象使您可以访问该文件的各种方法 例如读取或写入 gt gt gt f open file0 gt gt gt f
  • 无法通过 Android 应用程序访问我的笔记本电脑的本地主机

    因此 我在发布此内容之前做了一项研究 我发现的解决方案不起作用 更准确地说 连接到我的笔记本电脑的 IPv4192 168 XXX XXX 没用 连接到10 0 2 2 加上端口 不起作用 我需要测试使用 Django Rest 框架构建的
  • Pandas 字典键到列[重复]

    这个问题在这里已经有答案了 我有一个像这样的数据框 index column1 e1 u c680 5 u c681 1 u c682 2 u c57 e2 u c680 6 u c681 2 u c682 1 u c57 e3 u c68
  • 数据损坏 C++ 和 Python 之间的管道

    我正在编写一些代码 从 Python 获取二进制数据 将其通过管道传输到 C 对数据进行一些处理 在本例中计算互信息度量 然后将结果通过管道传输回 Python 在测试时 我发现如果我发送的数据是一组尺寸小于 1500 X 1500 的 2
  • 类返回语句不打印任何输出

    我正在学习课程 但遇到了问题return语句 它是语句吗 我希望如此 程序什么也没有打印出来 它只是结束而不做任何事情 class className def createName self name self name name def
  • 用 pandas DataFrame 替换 mysql 数据库表中的行

    Python 版本 2 7 6 熊猫版本 0 17 1 MySQLdb 版本 1 2 5 在我的数据库中 PRODUCT 我有一张桌子 XML FEED 表 XML FEED 很大 数百万条记录 我有一个 pandas DataFrame
  • 如何有效地比较 pandas DataFrame 中的行?

    我有一个 pandas 数据框 其中包含雷击记录以及时间戳和全球位置 格式如下 Index Date Time Lat Lon Good fix 0 1 20160101 00 00 00 9962692 7 1961 60 7604 1
  • py2exe ImportError:没有名为 的模块

    我已经实现了一个名为 myUtils 的包 它由文件夹 myUtils 文件 组成 init py 和许多名称为 myUtils 的 py 文件 该包包含在 myOtherProject py 中 当我从 Eclipse 运行它们时可以找到
  • Tkinter 将鼠标点击绑定到框架

    我一定错过了一些明显的东西 我的 Tkinter 程序中有两个框架 每个框架在网格布局中都有一堆标签 我想将鼠标点击绑定到其中一个而不是另一个 我目前使用 root bind

随机推荐

  • python 中单词的动名词形式

    我想获得字符串的动名词形式 我还没有找到调用库来获取动名词的直接方法 我应用了以 ing 结尾的单词的规则 但是因为异常导致我收到了一些错误 然后 我检查 cmu 单词以确保生成的动名词单词正确 代码如下 import cmudict im
  • 返回参数的类型名查找

    最近有一个学生问我一个编译问题 答案很简单 但现在我正在努力寻找原因 一个简单的例子 include
  • 获取 Jenkins 多分支管道中的分支列表

    Jenkins 多分支管道项目的 Blue Ocean 界面显示了自动创建的多个分支 是否有一种编程方式可以从要添加到 Jenkinsfile 的代码中列出管道中的分支 此问题询问位于 Jenkins 应用程序对象模型内的 Jenkins
  • 模板化成员函数的地址[重复]

    这个问题在这里已经有答案了 在下面的例子中 如何找到成员函数f的地址 template
  • 时间戳格式 - 从 1/1000 秒到 1/100 秒

    需要将1 1000秒分辨率的时间戳转换为1 100分辨率 我可能会用to char timestamp text 用于此目的的格式化功能 但是需要帮助text在这里使用 输入表 注意 这里的时间戳存储为 varchar ms1000 val
  • 复杂对象上的自定义 NSSortDescriptor

    这是我的第一篇文章 如果我可能不尊重所有惯例 我很抱歉 尽管我会尽力而为 我以前总是在 SO 上找到解决我的问题的方法 但我完全陷入了一个相当复杂的可可问题 我正在尝试对 CoreData 对象列表进行复杂的排序 我有一个由 Book 对象
  • Jquery:停止传播?

    我已经添加了 stopPropagation 但是 我仍然连续出现两个弹出窗口 这比以前好多了 其中一个被单击的元素有 20 个弹出窗口 是否有更好的方法或者我错过了什么 top document ready function click
  • 如何在 Android 浏览器上阻止某些网址?

    如何在 Android 默认浏览器上阻止某些网址 网站 我想限制用户访问某些列入黑名单的网址 例如 如果我想阻止 Facebook 那么手机内置应用程序浏览器将无法访问此 Facebook 网站 您想通过让用户安装应用程序来阻止用户设备上的
  • 使用 shutdown 终止 Amazon EC2 实例

    我可以使用 API 命令终止 Amazon EC2 实例ec2 终止实例但我试图找出如何在登录到 EC2 实例本身时执行此操作 我试过了立即关闭 h但这只是 停止 实例 而没有完全终止它 有什么办法可以做到这一点吗 您可以在创建实例时设置一
  • tkinter 无法正确识别屏幕分辨率

    我使用的是 4k 显示器 3840x2160 from tkinter import root Tk width root winfo screenwidth height root winfo screenheight print wid
  • MVC DropDownList SelectedValue 未正确显示

    我尝试搜索 但没有找到任何可以解决我的问题的内容 我在 Razor 视图上有一个 DropDownList 它不会显示我在 SelectList 中标记为 选定 的项目 以下是填充列表的控制器代码 var statuses new Sele
  • 密码强度计

    我正在尝试创建自己的 JS 密码强度计 它之前可以工作 但我不喜欢它的工作方式 所以我尝试使用 score 10 而不仅仅是 score 这是我的代码 http jsfiddle net RSq4L http jsfiddle net RS
  • 为什么 Kotlin 编译器需要 var 属性的显式初始化器?

    我无法理解以下 Kotlin 文档 The initializer getter and setter are optional Property type is optional if it can be inferred from th
  • 无法读取未定义的属性“forEach”

    var funcs 1 2 forEach i gt funcs push gt i 为什么会产生下面的错误 TypeError Cannot read property forEach of undefined at Object
  • 如何为 JVectorMap jquery 插件生成新的自定义地图?

    有用的链接 JVectorMap http jvectormap com http jvectormap com 购物中心示例 http jvectormap com examples mall http jvectormap com ex
  • Scala 中 def 和 val 的区别

    循环定义如下 def loop Boolean loop 当x定义为 def x loop然后控制台中会显示 x Boolean and 当x定义为 val x loop然后就进入无限循环 我知道 def 正在使用按名称调用 而 val 正
  • 不允许主机连接到此 MySQL 服务器以进行客户端-服务器应用程序

    我刚刚将表从一台 Web 主机导出到另一台 AWS 以为一切都会顺利 是的 没错 好吧 一切可能出错的事情都已经出错了 尝试查询我的数据库时出现此错误 我之前没有得到过 SQLSTATE HY000 1130 Host
  • Android:如何使用单个按钮执行多个任务

    我有 1 个按钮处于活动状态 我想使用这个 1 按钮来执行多项任务 那么我该怎么办呢 如果我第一次按此按钮 则更改 2 次按钮 如果我按第二次 它就会更新我的数据 但这只是第一次工作第二次就不起作用了 查看我的代码我尝试了什么 Intent
  • 更好的数据库设计是:更多的表还是更多的列?

    一位前同事坚持认为 具有更多表且每个列较少的数据库比具有较少表且每个列较多的数据库更好 例如 您将拥有一个名称表 一个地址表 一个城市表等 而不是包含名称 地址 城市 州 邮政编码等列的客户表 他认为这种设计更加高效和灵活 也许它更灵活 但
  • Caffe 中的预测 - 异常:输入 blob 参数与网络输入不匹配

    我使用 Caffe 使用非常简单的 CNN 结构对非图像数据进行分类 我在尺寸为 n x 1 x 156 x 12 的 HDF5 数据上训练网络没有任何问题 但是 我在对新数据进行分类时遇到了困难 如何在不进行任何预处理的情况下进行简单的前