首页 > 编程 > Python > 正文

利用Tensorflow的队列多线程读取数据方式

2020-02-15 21:23:45
字体:
来源:转载
供稿:网友

在tensorflow中,有三种方式输入数据

1. 利用feed_dict送入numpy数组

2. 利用队列从文件中直接读取数据

3. 预加载数据

其中第一种方式很常用,在tensorflow的MNIST训练源码中可以看到,通过feed_dict={},可以将任意数据送入tensor中。

第二种方式相比于第一种,速度更快,可以利用多线程的优势把数据送入队列,再以batch的方式出队,并且在这个过程中可以很方便地对图像进行随机裁剪、翻转、改变对比度等预处理,同时可以选择是否对数据随机打乱,可以说是非常方便。该部分的源码在tensorflow官方的CIFAR-10训练源码中可以看到,但是对于刚学习tensorflow的人来说,比较难以理解,本篇博客就当成我调试完成后写的一篇总结,以防自己再忘记具体细节。

读取CIFAR-10数据集

按照第一种方式的话,CIFAR-10的读取只需要写一段非常简单的代码即可将测试集与训练集中的图像分别读取:

path = 'E:/Dataset/cifar-10/cifar-10-batches-py'# extract train examplesnum_train_examples = 50000x_train = np.empty((num_train_examples, 32, 32, 3), dtype='uint8')y_train = np.empty((num_train_examples), dtype='uint8')for i in range(1, 6):  fpath = os.path.join(path, 'data_batch_' + str(i))  (x_train[(i - 1) * 10000: i * 10000, :, :, :], y_train[(i - 1) * 10000: i * 10000])   = load_and_decode(fpath)# extract test examplesfpath = os.path.join(path, 'test_batch')x_test, y_test = load_and_decode(fpath)return x_train, y_train, x_test, np.array(y_test)

其中load_and_decode函数只需要按照CIFAR-10官网给出的方式decode就行,最终返回的x_train是一个[50000, 32, 32, 3]的ndarray,但对于ndarray来说,进行预处理就要麻烦很多,为了取mini-SGD的batch,还自己写了一个类,通过调用train_set.next_batch()函数来取,总而言之就是什么都要自己动手,效率确实不高

但对于第二种方式,读取起来就要麻烦很多,但使用起来,又快又方便

首先,把CIFAR-10的测试集文件读取出来,生成文件名列表

path = 'E:/Dataset/cifar-10/cifar-10-batches-py'filenames = [os.path.join(path, 'data_batch_%d' % i) for i in range(1, 6)]

有了列表以后,利用tf.train.string_input_producer函数生成一个读取队列

filename_queue = tf.train.string_input_producer(filenames)

接下来,我们调用read_cifar10函数,得到一幅一幅的图像,该函数的代码如下:

def read_cifar10(filename_queue): label_bytes = 1 IMAGE_SIZE = 32 CHANNELS = 3 image_bytes = IMAGE_SIZE*IMAGE_SIZE*3 record_bytes = label_bytes+image_bytes # define a reader reader = tf.FixedLengthRecordReader(record_bytes) key, value = reader.read(filename_queue) record_bytes = tf.decode_raw(value, tf.uint8) label = tf.strided_slice(record_bytes, [0], [label_bytes]) depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes],              [label_bytes + image_bytes]),        [CHANNELS, IMAGE_SIZE, IMAGE_SIZE]) image = tf.transpose(depth_major, [1, 2, 0]) return image, label            
发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表