如何以 HDF5 格式提供 caffe 多标签数据?

2024-05-03

我想将 caffe 与矢量标签一起使用,而不是整数。我检查了一些答案,似乎 HDF5 是更好的方法。但后来我陷入了这样的错误:

precision_layer.cpp:34] 检查失败:outer_num_ * inner_num_ == bottom[1]->count()(50 vs. 200)标签数量必须与预测数量匹配;例如,如果标签轴 == 1 并且预测形状为 (N, C, H, W),则标签计数(标签数量)必须为N*H*W,整数值为 {0, 1, ..., C-1}。

HDF5 创建为:

f = h5py.File('train.h5', 'w')
f.create_dataset('data', (1200, 128), dtype='f8')
f.create_dataset('label', (1200, 4), dtype='f4')

我的网络是由以下内容生成的:

def net(hdf5, batch_size):
    n = caffe.NetSpec()
    n.data, n.label = L.HDF5Data(batch_size=batch_size, source=hdf5, ntop=2)
    n.ip1 = L.InnerProduct(n.data, num_output=50, weight_filler=dict(type='xavier'))
    n.relu1 = L.ReLU(n.ip1, in_place=True)
    n.ip2 = L.InnerProduct(n.relu1, num_output=50, weight_filler=dict(type='xavier'))
    n.relu2 = L.ReLU(n.ip2, in_place=True)
    n.ip3 = L.InnerProduct(n.relu1, num_output=4, weight_filler=dict(type='xavier'))
    n.accuracy = L.Accuracy(n.ip3, n.label)
    n.loss = L.SoftmaxWithLoss(n.ip3, n.label)
    return n.to_proto()

with open(PROJECT_HOME + 'auto_train.prototxt', 'w') as f:
f.write(str(net('/home/romulus/code/project/train.h5list', 50)))

with open(PROJECT_HOME + 'auto_test.prototxt', 'w') as f:
f.write(str(net('/home/romulus/code/project/test.h5list', 20)))

看来我应该增加标签数量并将内容放入整数而不是数组中,但如果我这样做,caffe 会抱怨数据数量和标签不相等,然后存在。

那么,输入多标签数据的正确格式是什么?

另外,我很想知道为什么没有人简单地编写 HDF5 映射到 caffe blob 的数据格式?


回答这个问题的标题:

HDF5 文件的根目录中应有两个数据集,分别命名为“data”和“label”。形状是(data amount, dimension)。我只使用一维数据,所以我不确定顺序是什么channel, width, and height。也许这并不重要。dtype应该是浮动或双精度。

创建训练集的示例代码h5py is:



import h5py, os
import numpy as np

f = h5py.File('train.h5', 'w')
# 1200 data, each is a 128-dim vector
f.create_dataset('data', (1200, 128), dtype='f8')
# Data's labels, each is a 4-dim vector
f.create_dataset('label', (1200, 4), dtype='f4')

# Fill in something with fixed pattern
# Regularize values to between 0 and 1, or SigmoidCrossEntropyLoss will not work
for i in range(1200):
    a = np.empty(128)
    if i % 4 == 0:
        for j in range(128):
            a[j] = j / 128.0;
        l = [1,0,0,0]
    elif i % 4 == 1:
        for j in range(128):
            a[j] = (128 - j) / 128.0;
        l = [1,0,1,0]
    elif i % 4 == 2:
        for j in range(128):
            a[j] = (j % 6) / 128.0;
        l = [0,1,1,0]
    elif i % 4 == 3:
        for j in range(128):
            a[j] = (j % 4) * 4 / 128.0;
        l = [1,0,1,1]
    f['data'][i] = a
    f['label'][i] = l

f.close()
  

此外,不需要精度层,只需将其删除即可。下一个问题是损失层。自从SoftmaxWithLoss只有一个输出(具有最大值的维度的索引),它不能用于多标签问题。感谢 Adian 和 Shai,我发现SigmoidCrossEntropyLoss在这种情况下很好。

下面是完整的代码,从数据创建、训练网络到获取测试结果:

main.py(根据caffelanet示例修改)



import os, sys

PROJECT_HOME = '.../project/'
CAFFE_HOME = '.../caffe/'
os.chdir(PROJECT_HOME)

sys.path.insert(0, CAFFE_HOME + 'caffe/python')
import caffe, h5py

from pylab import *
from caffe import layers as L

def net(hdf5, batch_size):
    n = caffe.NetSpec()
    n.data, n.label = L.HDF5Data(batch_size=batch_size, source=hdf5, ntop=2)
    n.ip1 = L.InnerProduct(n.data, num_output=50, weight_filler=dict(type='xavier'))
    n.relu1 = L.ReLU(n.ip1, in_place=True)
    n.ip2 = L.InnerProduct(n.relu1, num_output=50, weight_filler=dict(type='xavier'))
    n.relu2 = L.ReLU(n.ip2, in_place=True)
    n.ip3 = L.InnerProduct(n.relu2, num_output=4, weight_filler=dict(type='xavier'))
    n.loss = L.SigmoidCrossEntropyLoss(n.ip3, n.label)
    return n.to_proto()

with open(PROJECT_HOME + 'auto_train.prototxt', 'w') as f:
    f.write(str(net(PROJECT_HOME + 'train.h5list', 50)))
with open(PROJECT_HOME + 'auto_test.prototxt', 'w') as f:
    f.write(str(net(PROJECT_HOME + 'test.h5list', 20)))

caffe.set_device(0)
caffe.set_mode_gpu()
solver = caffe.SGDSolver(PROJECT_HOME + 'auto_solver.prototxt')

solver.net.forward()
solver.test_nets[0].forward()
solver.step(1)

niter = 200
test_interval = 10
train_loss = zeros(niter)
test_acc = zeros(int(np.ceil(niter * 1.0 / test_interval)))
print len(test_acc)
output = zeros((niter, 8, 4))

# The main solver loop
for it in range(niter):
    solver.step(1)  # SGD by Caffe
    train_loss[it] = solver.net.blobs['loss'].data
    solver.test_nets[0].forward(start='data')
    output[it] = solver.test_nets[0].blobs['ip3'].data[:8]

    if it % test_interval == 0:
        print 'Iteration', it, 'testing...'
        correct = 0
        data = solver.test_nets[0].blobs['ip3'].data
        label = solver.test_nets[0].blobs['label'].data
        for test_it in range(100):
            solver.test_nets[0].forward()
            # Positive values map to label 1, while negative values map to label 0
            for i in range(len(data)):
                for j in range(len(data[i])):
                    if data[i][j] > 0 and label[i][j] == 1:
                        correct += 1
                    elif data[i][j] %lt;= 0 and label[i][j] == 0:
                        correct += 1
        test_acc[int(it / test_interval)] = correct * 1.0 / (len(data) * len(data[0]) * 100)

# Train and test done, outputing convege graph
_, ax1 = subplots()
ax2 = ax1.twinx()
ax1.plot(arange(niter), train_loss)
ax2.plot(test_interval * arange(len(test_acc)), test_acc, 'r')
ax1.set_xlabel('iteration')
ax1.set_ylabel('train loss')
ax2.set_ylabel('test accuracy')
_.savefig('converge.png')

# Check the result of last batch
print solver.test_nets[0].blobs['ip3'].data
print solver.test_nets[0].blobs['label'].data
  

h5list 文件仅包含每行中 h5 文件的路径:

火车.h5list

/home/foo/bar/project/train.h5

测试.h5列表

/home/foo/bar/project/test.h5

和求解器:

auto_solver.prototxt


train_net: "auto_train.prototxt"
test_net: "auto_test.prototxt"
test_iter: 10
test_interval: 20
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
lr_policy: "inv"
gamma: 0.0001
power: 0.75
display: 100
max_iter: 10000
snapshot: 5000
snapshot_prefix: "sed"
solver_mode: GPU
  

Converge graph: Converge graph

上一批结果:



[[ 35.91593933 -37.46276474 -6.2579031 -6.30313492]
[ 42.69248581 -43.00864792 13.19664764 -3.35134125]
[ -1.36403108 1.38531208 2.77786589 -0.34310576]
[ 2.91686511 -2.88944006 4.34043217 0.32656598]
...
[ 35.91593933 -37.46276474 -6.2579031 -6.30313492]
[ 42.69248581 -43.00864792 13.19664764 -3.35134125]
[ -1.36403108 1.38531208 2.77786589 -0.34310576]
[ 2.91686511 -2.88944006 4.34043217 0.32656598]]

[[ 1. 0. 0. 0.]
[ 1. 0. 1. 0.]
[ 0. 1. 1. 0.]
[ 1. 0. 1. 1.]
...
[ 1. 0. 0. 0.]
[ 1. 0. 1. 0.]
[ 0. 1. 1. 0.]
[ 1. 0. 1. 1.]]
  

我认为这段代码还有很多地方需要改进。任何建议表示赞赏。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何以 HDF5 格式提供 caffe 多标签数据? 的相关文章

  • 在 Django 中定义视图和 url。为什么调用函数时不使用括号?

    我已经在经历 Python速成课程 目前正在进行 Django Web应用程序项目 学习日志 阶段 有些东西与我已经学到的相矛盾 views py file from django shortcuts import render def i
  • 使用 python 制作本地服务器应用程序的最佳方法

    我想要简单轻松地集成 python 和 vba 人们 如果他们在阅读本文后亲自见到我 阅读本文可能会杀了我 但我正在使用 django 开发服务器来实现此目的 有没有什么简单又好的方法 仅举个例子 我想使用 python 模块 openpy
  • python 可以检测它运行在哪个操作系统下吗?

    python 可以检测操作系统 然后为文件系统构建 if else 语句吗 我需要将 Fn 字符串中的 C CobaltRCX 替换为 FileSys 字符串 import os path csv from time import strf
  • 如何屏蔽 PyTorch 权重参数中的权重?

    我正在尝试在 PyTorch 中屏蔽 强制为零 特定权重值 我试图掩盖的权重是这样定义的def init class LSTM MASK nn Module def init self options inp dim super LSTM
  • 在 Python 中使用 XPath 和 LXML

    我有一个 python 脚本 用于解析 XML 并将某些感兴趣的元素导出到 csv 文件中 我现在尝试更改脚本以允许根据条件过滤 XML 文件 等效的 XPath 查询将是 DC Events Confirmation contains T
  • 使用 Django 的 post_save() 信号

    我有两张桌子 class Advertisement models Model created at models DateTimeField auto now add True author email models EmailField
  • Dask DataFrame 的逐行处理

    我需要处理一个大文件并更改一些值 我想做这样的事情 for index row in dataFrame iterrows foo doSomeStuffWith row lol doOtherStuffWith row dataFrame
  • 类属性在功能上依赖于其他类属性

    我正在尝试使用静态类属性来定义另一个静态类属性 我认为可以通过以下代码来实现 f lambda s s 1 class A foo foo bar f A foo 然而 这导致NameError name A is not defined
  • NLTK、搭配问题:需要解包的值太多(预期为 2)

    我尝试使用 NLTK 检索搭配 但出现错误 我使用内置的古腾堡语料库 I wrote alice nltk corpus gutenberg fileids 7 al nltk corpus gutenberg words alice al
  • 反加入熊猫

    我有两个表 我想附加它们 以便仅保留表 A 中的所有数据 并且仅在其键唯一时添加表 B 中的数据 键值在表 A 和 B 中是唯一的 但在某些情况下键将出现在表 A 和 B 中 我认为执行此操作的方法将涉及某种过滤联接 反联接 以获取表 B
  • 如何为多组精灵创建随机位置?

    我尝试使用 blit 和 draw 方法进行 for 循环 并为 PlayerSprite 和 Treegroup 使用不同的变量 for PlayerSprite in Treegroup surface blit PlayerSprit
  • 使用Python将图像转换为十六进制格式

    我的下面有一个jpg文件tmp folder upload path tmp resized test jpg 我一直在使用下面的代码 Method 1 with open upload path rb as image file enco
  • 在 Mac 上安装 Pygame 到 Enthought 构建中

    关于在 Mac 上安装 Pygame 有许多未解答的问题 但我将在这里提出我的具体问题并希望得到答案 我在 Mac 上安装 Pygame 时遇到了难以置信的困难 我使用 Enthought 版本 EPD 7 3 2 32 位 它是我的默认框
  • 在 Windows 上使用 IPython 笔记本时出现 500 服务器错误

    我刚刚在 Windows 7 Professional 64 位上全新安装了 IPython 笔记本 我采取的步骤是 从以下位置安装 Python 3 4 1http python org http python org gt pip in
  • Python int 太大,无法放入 SQLite

    我收到错误 OverflowError Python int 太大 无法转换为 SQLite INTEGER 来自以下代码块 该文件约25GB 因此必须分部分读取 length 6128765 Works on partitions of
  • 是否可以写一个负的python类型注释

    这可能听起来不合理 但现在我需要否定类型注释 我的意思是这样的 an int Not Iterable a string Iterable 这是因为我为一个函数编写了一个重载 而 mypy 不理解我 我的功能看起来像这样 overload
  • Plotly:如何避免巨大的 html 文件大小

    我有一个 3D 装箱模型 它使用绘图来绘制输出图 我注意到 绘制了 600 个项目 生成 html 文件需要很长时间 文件大小为 89M 这太疯狂了 我怀疑可能存在一些巨大的重复 或者是由单个项目的 add trace 方法引起的 阴谋 为
  • Python模块单元测试的最佳文件结构组织?

    遗憾的是 我发现有太多方法可以在 Python 中保存单元测试 而且它们通常没有很好的文档记录 我正在寻找一种 终极 结构 它可以满足以下大部分要求 be discoverable by test frameworks including
  • Google App Engine 中的自定义身份验证

    有谁知道或知道我可以在哪里学习如何使用 Python 和 Google App Engine 创建自定义身份验证流程 我不想使用 Google 帐户进行身份验证 并且希望能够创建自己的用户 如果不是专门针对 Google App Engin
  • 如何识别图形线条

    我有以下格式的路径的 x y 数据 示例仅用于说明 seq p1 p2 0 20 2 3 1 20 2 4 2 20 4 4 3 22 5 5 4 22 5 6 5 23 6 2 6 23 6 3 7 23 6 4 每条路径都有多个点 它们

随机推荐