首页 > 编程 > Python > 正文

TensorFlow数据输入的方法示例

2020-02-15 21:54:53
字体:
来源:转载
供稿:网友

读取数据(Reading data)

TensorFlow输入数据的方式有四种:

    tf.data API:可以很容易的构建一个复杂的输入通道(pipeline)(首选数据输入方式)(Eager模式必须使用该API来构建输入通道) Feeding:使用Python代码提供数据,然后将数据feeding到计算图中。 QueueRunner:基于队列的输入通道(在计算图计算前从队列中读取数据) Preloaded data:用一个constant常量将数据集加载到计算图中(主要用于小数据集)

1. tf.data API

关于tf.data.Dataset的更详尽解释请看《programmer's guide》。tf.data API能够从不同的输入或文件格式中读取、预处理数据,并且对数据应用一些变换(例如,batching、shuffling、mapping function over the dataset),tf.data API 是旧的 feeding、QueueRunner的升级。

2. Feeding

注意:Feeding是数据输入效率最低的方式,应该只用于小数据集和调试(debugging)

TensorFlow的Feeding机制允许我们将数据输入计算图中的任何一个Tensor。因此可以用Python来处理数据,然后直接将处理好的数据feed到计算图中 。

run()eval()中用feed_dict来将数据输入计算图:

with tf.Session(): input = tf.placeholder(tf.float32) classifier = ... print(classifier.eval(feed_dict={input: my_python_preprocessing_fn()}))

虽然你可以用feed data替换任何Tensor的值(包括variables和constants),但最好的使用方法是使用一个tf.placeholder节点(专门用于feed数据)。它不用初始化,也不包含数据。一个placeholder没有被feed数据,则会报错。

使用placeholder和feed_dict的一个实例(数据集使用的是MNIST)见tensorflow/examples/tutorials/mnist/fully_connected_feed.py

3. QueueRunner

注意:这一部分介绍了基于队列(Queue)API构建输入通道(pipelines),这一方法完全可以使用 tf.data API来替代。

一个基于queue的从文件中读取records的通道(pipline)一般有以下几个步骤:

    文件名列表(The list of filenames) 文件名打乱(可选)(Optional filename shuffling) epoch限制(可选)(Optional epoch limit) 文件名队列(Filename queue) 与文件格式匹配的Reader(A Reader for the file format) decoder(A decoder for a record read by the reader) 预处理(可选)(Optional preprocessing) Example队列(Example queue)

3.1 Filenames, shuffling, and epoch limits

对于文件名列表,有很多方法:1. 使用一个constant string Tensor(比如:["file0", "file1"])或者 [("file%d" %i) for i in range(2)];2. 使用 tf.train.match_filenames_once 函数;3. 使用 tf.gfile.Glob(path_pattern)

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表