首页 > 编程 > Python > 正文

Tensorflow 多线程与多进程数据加载实例

2020-02-15 21:23:38
字体:
来源:转载
供稿:网友

在项目中遇到需要处理超级大量的数据集,无法载入内存的问题就不用说了,单线程分批读取和处理(虽然这个处理也只是特别简单的首尾相连的操作)也会使瓶颈出现在CPU性能上,所以研究了一下多线程和多进程的数据读取和预处理,都是通过调用dataset api实现

1. 多线程数据读取

第一种方法是可以直接从csv里读取数据,但返回值是tensor,需要在sess里run一下才能返回真实值,无法实现真正的并行处理,但如果直接用csv文件或其他什么文件存了特征值,可以直接读取后进行训练,可使用这种方法.

import tensorflow as tf#这里是返回的数据类型,具体内容无所谓,类型对应就好了,比如我这个,就是一个四维的向量,前三维是字符串类型 最后一维是int类型record_defaults = [[""], [""], [""], [0]]def decode_csv(line): parsed_line = tf.decode_csv(line, record_defaults) label = parsed_line[-1]  # label  del parsed_line[-1]   # delete the last element from the list features = tf.stack(parsed_line) # Stack features so that you can later vectorize forward prop., etc. #label = tf.stack(label)   #NOT needed. Only if more than 1 column makes the label... batch_to_return = features, label return batch_to_returnfilenames = tf.placeholder(tf.string, shape=[None])dataset5 = tf.data.Dataset.from_tensor_slices(filenames)#在这里设置线程数目dataset5 = dataset5.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1).map(decode_csv,num_parallel_calls=15)) dataset5 = dataset5.shuffle(buffer_size=1000)dataset5 = dataset5.batch(32) #batch_sizeiterator5 = dataset5.make_initializable_iterator()next_element5 = iterator5.get_next()#这里是需要加载的文件名training_filenames = ["train.csv"]validation_filenames = ["vali.csv"]with tf.Session() as sess: for _ in range(2):   	#通过文件名初始化迭代器  sess.run(iterator5.initializer, feed_dict={filenames: training_filenames})  while True:   try:   #这里获得真实值    features, labels = sess.run(next_element5)    # Train...   # print("(train) features: ")   # print(features)   # print("(train) labels: ")   # print(labels)    except tf.errors.OutOfRangeError:    print("Out of range error triggered (looped through training set 1 time)")    break # Validate (cost, accuracy) on train set print("/nDone with the first iterator/n") sess.run(iterator5.initializer, feed_dict={filenames: validation_filenames}) while True:  try:   features, labels = sess.run(next_element5)   # Validate (cost, accuracy) on dev set  # print("(dev) features: ")  # print(features)  # print("(dev) labels: ")  # print(labels)  except tf.errors.OutOfRangeError:   print("Out of range error triggered (looped through dev set 1 time only)")   break 

第二种方法,基于生成器,可以进行预处理操作了,sess里run出来的结果可以直接进行输入训练,但需要自己写一个生成器,我使用的测试代码如下:

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表