数据集加载
1.keas.datasets
tensoflow.keras提供了keras.datasets的接口
常见的数据集:
Boston housing price regerssion dataset
MNIST/Fashion-MNIST dataset
sentiment classification dataset(imdb)
small images classidication dataset(CIFAR10/100)
数据集加载步骤
Step0: 准备要加载的numpy数据
Step1: 使用 tf.data.Dataset.from_tensor_slices() 函数进行加载
Step2: 使用 shuffle() 打乱数据
Step3: 使用 map() 函数进行预处理
Step4: 使用 batch() 函数设置 batch size 值
Step5: 根据需要 使用 repeat() 设置是否循环迭代数据集
MNIST
![](https://img-blog.csdnimg.cn/20201021184838341.png)
keras.datasets.mnist.load_data()
将MNIST数据集加载并处理成Numpy格式。
(x,y)是60000张训练数据集,
(x_test,y_test)是10000张测试数据集
其中y和y_test存储的是0~9的数字,代表每张图片的值。
y[:4]=[5,0,4,1]表示前四张图片的值分别为5,0,4,1。
tf.one_hot( )将y的数据转换成one_hot类型。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20201021185208869.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2RvbmdjamF2YQ==,size_16,color_FFFFFF,t_70#pic_center)
CIFAR10/100
10和100表示数据的类别,一共10个大类,每一个大类分成10个小类,共有100类。
图片的size=[32,32,3]数据很小。
![](https://img-blog.csdnimg.cn/20201021190948783.png)
共有60000张图片,其中50000张是训练数据集,10000张是测试数据集。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20201021191552786.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2RvbmdjamF2YQ==,size_16,color_FFFFFF,t_70#pic_center)
tf.data.Dataset.from_tensor_slices()
切分传入的 Tensor 的第一个维度,生成相应的 dataset
使用迭代器进行迭代
![在这里插入图片描述](https://img-blog.csdnimg.cn/20201021191920264.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2RvbmdjamF2YQ==,size_16,color_FFFFFF,t_70#pic_center)
.shuffle
将数据集打散
![在这里插入图片描述](https://img-blog.csdnimg.cn/20201021193504871.png#pic_center)
.map
数据预处理
![在这里插入图片描述](https://img-blog.csdnimg.cn/20201021193609873.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2RvbmdjamF2YQ==,size_16,color_FFFFFF,t_70#pic_center)
.batch
设置batch_size,下图的db2.batch(32),batch_size=32,以32张图片分为一组。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20201021194618120.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2RvbmdjamF2YQ==,size_16,color_FFFFFF,t_70#pic_center)
.repeat
设置数据集的迭代次数,空着则一致循环。
![](https://img-blog.csdnimg.cn/20201021195210950.png)
For Example:
def preparer_mnist_featues_and_lables(x, y):
x = tf.cast(x, tf.float32) / 255. #缩放到0~1范围
y = tf.cast(y, tf.int64)
retu x,y
def mnist_dataset():
(x, y), (x_val, y_val) = datasets.fanshion_muist.load_data() #加载数据集
# x:60k y:10k x_val:60k y_val:10k
y = tf.one_hot(y, depth = 10) #y:[10k,10]
y_val = tf.one_hot(y_val, depth = 10) #y_val:[10k,10]
ds = tf.data.Dataset.from_tenso_slices((x, y))
ds = ds.map(repae_mnist_featues_and_lables) #预处理
ds = ds.shuffle(60000).batch(100) #打散在batch
ds_val = tf.data.Dataset.from_tenso_slices((x_val, y_val))
ds_val = ds_val.map(repae_mnist_featues_and_lables)
ds_val = ds_val.shuffle(10000).batch(100)
return ds, ds_val