The underlying principle of how TensorFlow reads CSV data will not be covered in detail here; we go straight to the code.
Method 1:
Full code for reading tf_read.csv:
# coding: utf-8
import tensorflow as tf

filename_queue = tf.train.string_input_producer(["/home/yongcai/tf_read.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# record_defaults fixes the dtype of each of the four columns (float32 here)
# and supplies the value used when a field is empty
record_defaults = [[1.], [1.], [1.], [1.]]
col1, col2, col3, col4 = tf.decode_csv(value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3])

init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    sess.run(local_init_op)

    # Start populating the filename queue.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        for i in range(30):
            example, label = sess.run([features, col4])
            print(example)
            # print(label)
    except tf.errors.OutOfRangeError:
        print('Done !!!')
    finally:
        coord.request_stop()
        coord.join(threads)
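To make the role of record_defaults concrete, the following minimal sketch runs tf.decode_csv on a single literal CSV line (the first data row of tf_read.csv). This is an illustrative snippet in TF 1.x graph mode, not part of the original pipeline:

# Minimal sketch: tf.decode_csv on one literal CSV line (TF 1.x graph mode).
import tensorflow as tf

sample_line = tf.constant("-0.76,15.67,-0.12,15.67")
cols = tf.decode_csv(sample_line, record_defaults=[[1.], [1.], [1.], [1.]])

with tf.Session() as sess:
    print(sess.run(cols))  # four scalar float32 values, one per column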
tf_read.csv data:
-0.76,15.67,-0.12,15.67
-0.48,12.52,-0.06,12.51
1.33,9.11,0.12,9.1
-0.88,20.35,-0.18,20.36
-0.25,3.99,-0.01,3.99
-0.87,26.25,-0.23,26.25
-1.03,2.87,-0.03,2.87
-0.51,7.81,-0.04,7.81
-1.57,14.46,-0.23,14.46
-0.1,10.02,-0.01,10.02
-0.56,8.92,-0.05,8.92
-1.2,4.1,-0.05,4.1
-0.77,5.15,-0.04,5.15
-0.88,4.48,-0.04,4.48
-2.7,10.82,-0.3,10.82
-1.23,2.4,-0.03,2.4
-0.77,5.16,-0.04,5.15
-0.81,6.15,-0.05,6.15
-0.6,5.01,-0.03,5
-1.25,4.75,-0.06,4.75
-2.53,7.31,-0.19,7.3
-1.15,16.39,-0.19,16.39
-1.7,5.19,-0.09,5.18
-0.62,3.23,-0.02,3.22
-0.74,17.43,-0.13,17.41
-0.77,15.41,-0.12,15.41
0,47,0,47.01
0.25,3.98,0.01,3.98
-1.1,9.01,-0.1,9.01
-1.02,3.87,-0.04,3.87
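Method 1 pulls one record per sess.run call. If mini-batches are needed instead, the per-record tensors can be handed to tf.train.batch. The sketch below rebuilds the same reading graph and groups five records per batch; the batch size is an arbitrary choice for illustration:

# Sketch: batching the per-record tensors from Method 1 with tf.train.batch.
import tensorflow as tf

filename_queue = tf.train.string_input_producer(["/home/yongcai/tf_read.csv"])
reader = tf.TextLineReader()
_, value = reader.read(filename_queue)
col1, col2, col3, col4 = tf.decode_csv(value, record_defaults=[[1.], [1.], [1.], [1.]])
features = tf.stack([col1, col2, col3])

# Group 5 consecutive records; each sess.run yields shapes (5, 3) and (5,)
feature_batch, label_batch = tf.train.batch([features, col4], batch_size=5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        print(sess.run([feature_batch, label_batch]))
    finally:
        coord.request_stop()
        coord.join(threads)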
Method 2:
Full code for reading Iris-train.csv and Iris-test.csv:
# coding: utf-8
import os

import tensorflow as tf

os.chdir("/home/yongcai/")
print(os.getcwd())


def read_data(file_queue):
    # Skip the CSV header line, then parse one record at a time
    reader = tf.TextLineReader(skip_header_lines=1)
    key, value = reader.read(file_queue)
    defaults = [[0], [0.], [0.], [0.], [0.], ['']]
    Id, SepalLengthCm, SepalWidthCm, PetalLengthCm, PetalWidthCm, Species = tf.decode_csv(value, defaults)

    # Map the string label to an integer class id (-1 for anything unexpected)
    preprocess_op = tf.case({
        tf.equal(Species, tf.constant('Iris-setosa')): lambda: tf.constant(0),
        tf.equal(Species, tf.constant('Iris-versicolor')): lambda: tf.constant(1),
        tf.equal(Species, tf.constant('Iris-virginica')): lambda: tf.constant(2),
    }, lambda: tf.constant(-1), exclusive=True)

    return tf.stack([SepalLengthCm, SepalWidthCm, PetalLengthCm, PetalWidthCm]), preprocess_op


def create_pipeline(filename, batch_size, num_epochs=None):
    file_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
    example, label = read_data(file_queue)

    min_after_dequeue = 1000
    capacity = min_after_dequeue + batch_size
    example_batch, label_batch = tf.train.shuffle_batch(
        [example, label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue
    )

    return example_batch, label_batch


# x_train_batch, y_train_batch = create_pipeline('Iris-train.csv', 50, num_epochs=1000)
x_test, y_test = create_pipeline('Iris-test.csv', 60)

init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()

# output read data result
with tf.Session() as sess:
    sess.run(init_op)
    sess.run(local_init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        example, label = sess.run([x_test, y_test])
        print(example)
        print(label)
    except tf.errors.OutOfRangeError:
        print('Done !!!')
    finally:
        coord.request_stop()
        coord.join(threads)
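The session above only runs the test pipeline once. When num_epochs is set (as in the commented-out training pipeline), the input queue is finite and eventually raises OutOfRangeError, so the usual pattern is to loop until the coordinator stops. The sketch below assumes the Method 2 script above has already been run (so create_pipeline is defined and an Iris-train.csv with a header row exists); the batch size and epoch count are illustrative values, not taken from the original post:

# Sketch: draining a finite pipeline until OutOfRangeError (TF 1.x queue API).
# Relies on create_pipeline() defined above; values below are illustrative.
x_train_batch, y_train_batch = create_pipeline('Iris-train.csv', 50, num_epochs=10)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        while not coord.should_stop():
            features, labels = sess.run([x_train_batch, y_train_batch])
            print(features.shape, labels.shape)  # e.g. (50, 4) and (50,)
    except tf.errors.OutOfRangeError:
        print('Done !!!')
    finally:
        coord.request_stop()
        coord.join(threads)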