最近在学习TensorFlow,比较烦人的是使用tensorflow.examples.tutorials.mnist.input_data
读取数据
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/temp/mnist_data/')
X = mnist.test.images.reshape(-1, n_steps, n_inputs)
y = mnist.test.labels
时,经常出现网络连接错误
解决方法其实很简单,这里我们可以看一下input_data.py
的源代码(这里截取关键部分)
def maybe_download(filename, work_directory):
"""Download the data from Yann's website, unless it's already here."""
if not os.path.exists(work_directory):
os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
return filepath
可以看到,代码会先检查文件是否存在,如果不存在再进行下载,那么我是不是自己下载数据不就行了?
MNIST的数据集是从Yann LeCun教授的官网下载,下载完成之后修改一下我们读取数据的代码,加上我们下载的路径即可
from tensorflow.examples.tutorials.mnist import input_data
import os
data_path = os.path.join('.', 'temp', 'data')
mnist = input_data.read_data_sets(datapath)
X = mnist.test.images.reshape(-1, n_steps, n_inputs)
y = mnist.test.labels
测试一下
成功!