TensorFlow之tfrecords文件详细教程
制作数据集思路:
将训练数据和测试数据生成tfrecords文件 为什么呢? 这种文件以二进制进行存储,只占用一个内存块 对于大数据能够提高cpu效率
代码步骤
首先对数据进行处理 方便后面写入tfrecords文件
1:使用tf.python_io.TFRecordWriter('定义一个文件名')类 定义一个tfrecords 文件
2:将每一条样本数据按照相应的特征组织好,即将样本数据组织成Example的过程
3:将组织好的Example写入进tfrecords文件,并关闭tfrecords文件即可
import tensorflow as tf
import numpy as np
from PIL import Image
import os
data_path = './data'
"""
制作数据集思路:
将训练数据和测试数据生成tfrecords文件 为什么呢? 这种文件以二进制进行存储,只占用一个内存块 对于大数据能够提高cpu效率
代码步骤
首先对数据进行处理 方便后面写入tfrecords文件
1:使用tf.python_io.TFRecordWriter('定义一个文件名')类 定义一个tfrecords 文件
2:将每一条样本数据按照相应的特征组织好,即将样本数据组织成Example的过程
3:将组织好的Example写入进tfrecords文件,并关闭tfrecords文件即可
"""
def write_tfrecord(tfRecordName, image_path, label_path):
writer = tf.python_io.TFRecordWriter(tfRecordName)
num_pic = 0
f = open(label_path, 'r')
contents = f.readline()
f.close()
for content in contents:
value = content.split()
img_path = image_path + value[0]
img = Image.open(img_path)
img_raw = img.tobytes()
labels = [0] * 10
labels[int(value[1])] = 1
example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
}))
writer.write(example.SerializeToString())
num_pic += 1
print('the number of pidture :', num_pic)
writer.close()
print('writer tfrecord successful')
def generate_tfRecord():
isExists = os.path.exists(data_path)
if not isExists:
os.makedirs(data_path)
print('the directory was created successfully ')
else:
print('directory already exists')
def read_tfRecord(tfRecord_path):
filename_queue = tf.train.string_input_producer([tfRecord_path])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([10], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
})
img = tf.decode_raw(features['imf_raw'], tf.uint8)
img.set_shape([784])
img = tf.cast(img, tf.float32) * (1./255)
label = tf.cast(features['label', tf.float32])
return img, label
def get_tfrecode(num, isTrain=True):
img, label = read_tfRecord(tfRecord_path)
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size=num,
num_threads=2,
capacity=1000,
min_after_dequeue=700
)
return img_batch, label_batch
def main():
generate_tfRecord()
if __name__ == '__main__':
main()
https://blog.csdn.net/qq_27825451/article/details/83301811
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)