tensorflow mnist实战笔记(二)制作和读取自己的数据集

2023-10-30

这里面写的非常详细

http://www.itdadao.com/articles/c15a1401577p0.html

看了网上N多的教程,发现mnist的教程的数据都是官网已经制作好的,那么如果我们自己有数字图片,我们该怎么利用tensoeflow制作数据呢?

现在我有6万张训练集,1万张测试集,下载地址在这mnist图片数据下载:http://pan.baidu.com/s/1pLMV4Kz

首先我们需要有图片数据的txt表,以及对应的标签,如下所示,制作txt表在caffe中已经提到,传送门

mnist/train/5/00000.png 5
mnist/train/0/00001.png 0
mnist/train/4/00002.png 4
mnist/train/1/00003.png 1 下面这串代码就可以在原路径得到a.tfrecords文件

import numpy as np
import cv2
import tensorflow as tf

resize_height=28 #存储图片高度
resize_width=28 #存储图片宽度
train_file_root = '/home/hjxu/PycharmProjects/tf_examples/hjxu_mnist/mnist_img_data'
train_file = train_file_root+'/train.txt'     #trainfile是txt文件存放的目录

def _int64_feature(value):#将value转化成int64字节属性,
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):#将value转化成bytes属性
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def load_file(examples_list_file):
    # type: (object) -> object
    lines = np.genfromtxt(examples_list_file, delimiter=" ", dtype=[('col1', 'S120'), ('col2', 'i8')])
    examples = []
    labels = []
    for example, label in lines:
        examples.append(example)
        labels.append(label)
    return np.asarray(examples), np.asarray(labels), len(lines)
##load_file函数返回的examples,labels,lines,比如examples[0]指的是mnist/train/4/00002.png,也就是 txt的路径 labels[0]返回的是对应的label值是4

def extract_image(filename,  resize_height, resize_width):  #这边调用cv2.imread()来读取图像,由于cv2读取是BGR格式,需要转换成RGB格式
    image = cv2.imread(filename)
    image = cv2.resize(image, (resize_height, resize_width))
    b,g,r = cv2.split(image)
    rgb_image = cv2.merge([r,g,b])  # this is suitable
    rgb_image = rgb_image / 255.
    rgb_image = rgb_image.astype(np.float32)
    return rgb_image   

examples, labels, examples_num = load_file(train_file)
writer = tf.python_io.TFRecordWriter('/home/hjxu/PycharmProjects/tf_examples/hjxu_mnist/a.tfrecords')
# root = train_file_root + '/' + examples[0]
for i, [example, label] in enumerate(zip(examples, labels)):
    print('No.%d' % (i))
    root = train_file_root + '/' + examples[i]
    image = extract_image(root, resize_height, resize_width)
    a = image.shape
    print(root)
    print('shape: %d, %d, %d, label: %d' % (image.shape[0], image.shape[1], image.shape[2], label))
    image_raw = image.tostring() #将Image转化成字符
    example = tf.train.Example(features=tf.train.Features(feature={
        'image_raw': _bytes_feature(image_raw),
        'height': _int64_feature(image.shape[0]),
        'width': _int64_feature(image.shape[1]),
        'depth': _int64_feature(image.shape[2]),
        'label': _int64_feature(label)
    }))
    writer.write(example.SerializeToString())
writer.close()
上面代码最重要的是

 example = tf.train.Example(features=tf.train.Features(feature={
        'image_raw': _bytes_feature(image_raw),
        'height': _int64_feature(image.shape[0]),
        'width': _int64_feature(image.shape[1]),
        'depth': _int64_feature(image.shape[2]),
        'label': _int64_feature(label)
    }))
这一段,我们可以看出,a.tfrecords里面其实对应的是一些字典,比如Image_raw对应的是图像矩阵本身保存的字节文件,height则是则是对应的高,其实height什么的不写进去也没事,但label一定要写。

现在我们可以得到a.tfrecords这个文件,我们该怎么解析里面的内容呢?或者我们该怎么将tfrecords里面的二进制文件转换成我们可以可视化的数字图片呢

下面这串代码可以得出

import numpy as np

import matplotlib.pyplot as plt
import tensorflow as tf


tfrecord_list_file = '/home/hjxu/PycharmProjects/tf_examples/hjxu_mnist/a.tfrecords'



def read_and_decode(filename_queue,shuffle_batch=True):

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(serialized_example, features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    })

    image = tf.decode_raw(features['image_raw'],  tf.float32)
    image = tf.reshape(image, [28, 28, 3])
    image = image * 255.0

    labels = features['label']

    if shuffle_batch:
        images, labels = tf.train.shuffle_batch(
            [image,labels],
            batch_size=4,
            capacity=8000,
            num_threads=4,
            min_after_dequeue=2000)
    else:
        images,labels = tf.train.batch([image,labels],
                   batch_size=4,
                   capacity=8000,
                   num_threads=4)
    return images,labels

def test_run(tfrecord_filename):
    filename_queue = tf.train.string_input_producer([tfrecord_filename],
                                                    num_epochs=3)
    images,labs = read_and_decode(filename_queue)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
    # meanvalue = meanfile['mean']               #如果在制作数据时减去的均值,则需要加上来


    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for i in range(1):
            imgs,labs = sess.run([images,labs])
            print 'batch' + str(i) + ': '
            #print type(imgs[0])

            for j in range(4):
                print str(labs[j])
                img = np.uint8(imgs[j] )
                plt.subplot(4, 2, j * 2 + 1)
                plt.imshow(img)
            plt.show()

        coord.request_stop()
        coord.join(threads)     #注意,要关闭文件

test_run('/home/hjxu/PycharmProjects/tf_examples/hjxu_mnist/a.tfrecords')
print ("has done")

主要用到tf.decode_raw,这个内置函数的意思是解析 tfrecords文件里的二进制数据,我的read_and_decode只返回图像和label,所以只需要用到tfrecords里面的image_raw和label

 image = tf.decode_raw(features['image_raw'],  tf.float32)  #解析image_raw数据,注意,tf.float32是数据类型,一定要和制作数据时用的类型一样
 image = tf.reshape(image, [28, 28, 3])
 image = image * 255.0    #我在制作数据时除了255,这里可以补回来或者不补

 labels = features['label'] #label则是对应的标签

目前了解的也就这么多,基本都是从其他博客整理得到的,下面是参考博客

cv2.imread()和caffe.io.loadimage的区别


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tensorflow mnist实战笔记(二)制作和读取自己的数据集 的相关文章

  • 计算机简史:从分布式到中心化的博弈螺旋

    作为应用程序开发商 我们有责任为用户的隐私和信息安全考虑 用户应该拥有控制自己信息数据的权利 这些用户数据应该在初始化的时候就被加密和保护 作者 Eric Elliott 书籍 函数式编程的兴衰 和 编程JavaScript应用 的作者 他
  • apt install命令安装失败(已解决)

    在安装virtualbox时出现如下错误 百度了许久没有找到答案 朋友指点下已解决 分享给你们 以上 在处理时有错误发生 getdeb repository N 忽略 getdeb list bck 于目录 etc apt sources
  • Transformers GitHub项目星标突破10万!新里程碑!

    点击下方卡片 关注 CVer 公众号 AI CV重磅干货 第一时间送达 点击进入 gt Transformer 微信技术交流群 转载自 新智元 编辑 桃子 导读 问世6年来 Transformer不仅成为NLP领域的主流模型 甚至成功向其他
  • c++遇到的问题剪辑

    1 VS中使用TextOutW 参数不匹配问题 BOOL CDC TextOutW int int const CString 不能将参数 3 从 const char 18 错误说明 VC程序中使用了文字输出函数 TextOut 20 2
  • centos8配置网络

    centos安装后配置网络连接 3 网络配置 3 1 查看自己主机ip 网关等信息 例如个人主机信息如下 使用ipconfig all 可以查看所有信息 包括DNS 3 2 设置vmware网络连接 vmware gt 我的计算机 gt 设
  • java.net.ProtocolException: exceeded content-length limit of XXX bytes

    场景 使用HttpURLConnection向服务器提交POST请求 模拟将评论的内容提交给服务器 url new URL http 10 0 2 2 9102 post comment HttpURLConnection connecti
  • Docker之一:账号注册

    Docker之一 账号注册 在学习Docker之前 我们需要在Docker官网注册一个账号 方便后续下载安装包等操作 在主页面点击 Sign in 注册 Sign Up 填写注册信息 然后点击 Sign Up 选择Free 如果您有需求 可
  • python读取excel文件并保存成array

    要使用xlrd包 import numpy as np import xlrd 读取excel的库 resArray 先声明一个空list data xlrd open workbook demo xlsx 读取文件 table data
  • 【SQL基础】SQL增删改查基本语句

    目录 1 SQL 增删改查基本语句 2 select 语句 2 1 select 基本语句 2 2 Select where 语句 2 3 Select order by 语句 2 4 Select group by 语句 3 Select
  • 求解决,ubuntu16.04 文件“桌面/文档/下载/图片…”切换为英文,出现无法打开的空链接文件,错误报告以及显示

    文章参考 ubuntu下面 将汉字桌面 下载 换成 英文 https blog csdn net tanhuanzheng article details 103557287 ubuntu16 04 vmware虚拟机 1 我们可以先将目录
  • Ubuntu 16.04配置国内高速apt-get更新源

    Ubuntu 16 04下载软件速度有点慢 因为默认的是从国外下载软件 那就更换到国内比较好的快速更新源 就是这些软件所在的服务器 一般直接百度Ubuntu更新源就能出来一大堆 这时候最好是找和自己Ubuntu版本一致的更新源 我的Ubun
  • MFC入门基础(九)消息对话框、文件对话框

    一 消息对话框主要是两种CWnd MessageBox 和AfxMessageBox 下面是在按钮点击事件中添加MessageBox的效果 如下 void Ctest02Dlg OnBnClickedAddButton TODO 在此添加控
  • form表单 input输入框及属性

  • 启动Hadoop时一直提示输入密码的问题(SSH配置)

    目录 1 首先检查自己是否有配置本地ssh免密登录 2 另外一种方式 科普 什么是SSH 情况如下图所示 经常弹出要要我输入password 1 首先检查自己是否有配置本地ssh免密登录 a 下载SSH服务 yum install open
  • js vue 使用 map和computed巧妙设计可选列表和已选列表的联动

    需求说明 当已选列表中存在了可选列表的选项 则在可选列表中做出标记 使用map和computed的巧妙写法 otherFiledList是已选数据 fieldList是可选数据 已选数据是可选数据构成的 div i class el ico
  • 16. Dubbo原理解析-集群&容错之router路由服务

    Router服务路由 根据路由规则从多个Invoker中选出一个子集AbstractDirectory是所有目录服务实现的上层抽象 它在list列举出所有invokers后 会在通过Router服务进行路由过滤 Router接口定义 pub
  • 2016——大数据版图

    编者注 原文是 FirstMark Capital 的 Matt Turck 的文章 本文全面总结了大数据领域的发展态势 分析认为尽管大数据作为一个术语似乎已经过气 但是大数据分析与应用才刚刚开始兴起 在与 AI 人工智能等新兴技术的结合下
  • JSON格式转MAP的6种方法

    JSON字符串自动转换 Created by zkn on 2016 8 22 public class JsonToMapTest01 public static void main String args String str 0 zh
  • MySQL中的各种自增ID

    微信搜索 coder home 或扫一扫下面的二维码 关注公众号 第一时间了解更多干货分享 还有各类视频教程资源 扫描它 带走我 文章目录 背景 自增ID的数据类型 单位换算规则 自增ID取值范围 无符号位的计算方式 有符号位的计算方式 i

随机推荐

  • JDialog弹窗

    JDialog弹窗 package com chen lesson4 import javax swing import java awt import java awt event ActionEvent import java awt
  • python后端学习(二)TCP客户端和服务端

    TCP简介 TCP协议 传输控制协议 英语 Transmission Control Protocol 缩写为 TCP 是一种面向连接的 可靠的 基于字节流的传输层通信协议 由IETF的RFC 793定义 TCP通信需要经过创建连接 数据传
  • 14 【接口规范和业务分层】

    14 接口规范和业务分层 1 接口规范 RESTful架构 1 1 什么是REST REST全称是Representational State Transfer 中文意思是表述 编者注 通常译为表征 性状态转移 它首次出现在2000年Roy
  • android EditText 实时监听输入框的内容

    在开发中很多时候我们都会用到EditText 对输入内容的实时监听也是不可或缺的 在android中为我们提供了TextWatcher这个类 我们只要extends这个类然后etColler addTextChangedListener e
  • C#基础知识框架整理

    目录 NET FrameWork框架 NET平台 类库 快速启动vs sln文件的使用 在解决方案里 csprog文件的使用 在项目文件夹里 排除语法错误 设置行号 设置字体 恢复出厂设置 自动切换运行的项目 C 的3种注释符 C 常用的快
  • 浙大计算机学院博士毕业论文要求,浙大在读博士需要3篇SCI 论文才能毕业,清华博士却不作要求!...

    原标题 浙大在读博士需要3篇SCI 论文才能毕业 清华博士却不作要求 最近 又进入了一年的秋招季 很多学子纷纷加入求职大军之中 但是今年却有不一样的声音 有在读研究生表示 学校对毕业要求提高 要在专业期刊发表论文 这成了比找工作更让人烦心的
  • Java整合七牛云进行文件的上传与下载

    一 七牛云的对象存储的介绍 七牛云对象存储 Kodo 是七牛云提供的高可靠 强安全 低成本 可扩展的存储服务 您可通过控制台 API SDK 等方式简单快速地接入七牛存储服务 实现海量数据的存储和管理 通过 Kodo 可以进行文件的上传 下
  • Filter与Listener

    目录 Filter 1 Filter概念 2 Filter快速入门 3 Filter细节 1 web xml配置 2 Filter执行流程 3 Filter生命周期方法 4 Filter配置详解 拦截路径配置 拦截方式配置 1 注解配置 2
  • micropython下载及安装编译过程

    本文根据 参考文献 实现基于Black F407VE开发板的micropython移植 为后期 stm32H743的 micropython作准备 参考 http docs micropython org en latest 1 下载mic
  • k8s 实战之路

    k8s就是kubernetes 关于k8s 基本属于运维的范畴 一般除了一线大厂会有自研的运维平台和运维团队去做这些事 其他的中小型公司都会要求自己的研发人员懂这些运维的东西 还有nginx等 k8s 刚接触 目前还没有在现实工作中实际操作
  • java 继承 异常_Java异常以及继承的一些问题

    Java异常以及继承的一些问题 类之间的关系 java异常类层次结构图 Throwable Throwable是 Java 语言中所有错误或异常的超类 Throwable包含两个子类 Error 和 Exception 它们通常用于指示发生
  • 【vue】el-upload 图片上传的封装

  • Android DataBinding的基本使用

    5 DataBinding https developer android com topic libraries data binding custom conversions 数据绑定库是一种支持库 借助该库 您可以使用声明性格式 而非
  • 基于pytorch的LSTM进行字符级文本生成实战

    基于pytorch的LSTM进行字符级文本生成实战 文章目录 基于pytorch的LSTM进行字符级文本生成实战 前言 一 数据集 二 代码实现 1 导入库及LSTM模型构建 2 数据预处理函数 3 训练函数 4 预测函数 5 文本生成函数
  • Searching the String 【ZOJ - 3228】【AC自动机+last跳板优化时间】

    题目链接 这次要求的有两个 分别是0 1 代表着的是可以重叠 以及不可以重叠的遍历到该单词的次数 可以重叠的很容易 遇到的时候 就直接加上就是了 但是不可以重叠的时候呢 就需要看到该单词它和上一次的状态出现的距离差了 看一下是否比这个单词长
  • Qt中使用QProcess调用第三方程序

    在Qt中调用第三方程序通常使用QProcess进行调用 以下描述QProcess常用的接口 1 QProcess startDetached QProcess startDetached const QString program cons
  • SQL将会员按照总消费金额从高到低分成50档。(分档、分组)

    面试题 交易表 有4个字段 订单号 会员id 消费金额 购买时间 问题 将会员按照总消费金额从高到低分成50档 解题步骤 1 解题思路 将某一个字段按区间分档 最先想到的是 猴子 从零学会SQL 里讲过的多条件语句 case when 但是
  • Python 计算机视觉(七)—— OpevCV进行直方图统计

    本文中涉及到的 matplotlib 绘图库的知识可以参考我的之前的文章 Python 绘图库 Matplotlib 目录 1 直方图概述 1 基本概念 2 直方图中的术语 BINS DIMS RANGE 2 直方图绘制 1 读取图像信息
  • 正交、独立、不相关区别

    一 三者的定义 假设X为一个随机过程 则在t1和t2时刻的随机变量的相关定义如下 两个随机过程一样 1 定义Rx t1 t2 E X t1 X t2 为相关函数 若R 0 称正交 注意 相关函数为0 不是不相关 而是正交 正交不仅仅描述确定
  • tensorflow mnist实战笔记(二)制作和读取自己的数据集

    这里面写的非常详细 http www itdadao com articles c15a1401577p0 html 看了网上N多的教程 发现mnist的教程的数据都是官网已经制作好的 那么如果我们自己有数字图片 我们该怎么利用tensoe