首页 > 编程 > Python > 正文

TensorFlow 读取CSV数据的实例

2020-02-15 21:22:37
字体:
来源:转载
供稿:网友

TensorFlow 读取CSV数据原理在此就不做详细介绍,直接通过代码实现:

方法一:

详细读取tf_read.csv 代码

# coding: utf-8
# Method 1: read tf_read.csv one line at a time with the TensorFlow 1.x
# queue-based input pipeline and print the first three columns of each row.
import tensorflow as tf

# Queue of input file names feeding the line reader.
filename_queue = tf.train.string_input_producer(["/home/yongcai/tf_read.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# One default per CSV column; the float defaults also fix each column's dtype.
record_defaults = [[1.], [1.], [1.], [1.]]
# NOTE(review): decode_csv splits on "," unless field_delim is given. If
# tf_read.csv is tab-separated, pass field_delim="\t" -- confirm with the file.
col1, col2, col3, col4 = tf.decode_csv(value, record_defaults=record_defaults)
# The first three columns form the feature vector; col4 is fetched as the label.
features = tf.stack([col1, col2, col3])

init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    sess.run(local_init_op)

    # Start populating the filename queue.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        # Each sess.run() dequeues and decodes exactly one CSV line.
        for i in range(30):
            example, label = sess.run([features, col4])
            print(example)
            # print(label)  # uncomment to also show the fourth column
    except tf.errors.OutOfRangeError:
        print('Done !!!')
    finally:
        # Always stop and join the queue-runner threads, even on error.
        coord.request_stop()
        coord.join(threads)

tf_read.csv 数据:

-0.76	15.67	-0.12	15.67
-0.48	12.52	-0.06	12.51
1.33	9.11	0.12	9.1
-0.88	20.35	-0.18	20.36
-0.25	3.99	-0.01	3.99
-0.87	26.25	-0.23	26.25
-1.03	2.87	-0.03	2.87
-0.51	7.81	-0.04	7.81
-1.57	14.46	-0.23	14.46
-0.1	10.02	-0.01	10.02
-0.56	8.92	-0.05	8.92
-1.2	4.1	-0.05	4.1
-0.77	5.15	-0.04	5.15
-0.88	4.48	-0.04	4.48
-2.7	10.82	-0.3	10.82
-1.23	2.4	-0.03	2.4
-0.77	5.16	-0.04	5.15
-0.81	6.15	-0.05	6.15
-0.6	5.01	-0.03	5
-1.25	4.75	-0.06	4.75
-2.53	7.31	-0.19	7.3
-1.15	16.39	-0.19	16.39
-1.7	5.19	-0.09	5.18
-0.62	3.23	-0.02	3.22
-0.74	17.43	-0.13	17.41
-0.77	15.41	-0.12	15.41
0	47	0	47.01
0.25	3.98	0.01	3.98
-1.1	9.01	-0.1	9.01
-1.02	3.87	-0.04	3.87

方法二:

详细读取 Iris_train.csv, Iris_test.csv 代码

# coding: utf-8
# Method 2: read the Iris CSV files through a shuffled, batched TensorFlow 1.x
# queue pipeline and print one batch of features and integer labels.
import tensorflow as tf
import os

os.chdir("/home/yongcai/")
print(os.getcwd())


def read_data(file_queue):
    """Decode one Iris CSV row from ``file_queue``.

    Args:
        file_queue: filename queue passed to ``TextLineReader.read``.

    Returns:
        A pair ``(features, label)``: ``features`` stacks the four
        measurement columns, and ``label`` is 0, 1 or 2 for
        'Iris-setosa', 'Iris-versicolor' or 'Iris-virginica', or -1 for
        any other species string.
    """
    # The first line of each file is skipped as a header.
    reader = tf.TextLineReader(skip_header_lines=1)
    _, value = reader.read(file_queue)
    # Column defaults: integer Id, four float measurements, string species.
    defaults = [[0], [0.], [0.], [0.], [0.], ['']]
    Id, SepalLengthCm, SepalWidthCm, PetalLengthCm, PetalWidthCm, Species = \
        tf.decode_csv(value, defaults)

    # Map the species name to an integer class id; -1 is the fallback branch.
    preprocess_op = tf.case({
        tf.equal(Species, tf.constant('Iris-setosa')): lambda: tf.constant(0),
        tf.equal(Species, tf.constant('Iris-versicolor')): lambda: tf.constant(1),
        tf.equal(Species, tf.constant('Iris-virginica')): lambda: tf.constant(2),
    }, lambda: tf.constant(-1), exclusive=True)

    return tf.stack([SepalLengthCm, SepalWidthCm, PetalLengthCm, PetalWidthCm]), preprocess_op


def create_pipeline(filename, batch_size, num_epochs=None):
    """Build a shuffled batch pipeline over a single CSV file.

    Args:
        filename: path of the CSV file to read.
        batch_size: number of examples per returned batch.
        num_epochs: forwarded to ``tf.train.string_input_producer``.

    Returns:
        A pair ``(example_batch, label_batch)`` of batched tensors.
    """
    file_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
    example, label = read_data(file_queue)

    # The shuffle queue holds one batch more than its minimum fill level.
    min_after_dequeue = 1000
    capacity = min_after_dequeue + batch_size
    example_batch, label_batch = tf.train.shuffle_batch(
        [example, label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue
    )

    return example_batch, label_batch


# x_train_batch, y_train_batch = create_pipeline('Iris-train.csv', 50, num_epochs=1000)
x_test, y_test = create_pipeline('Iris-test.csv', 60)

init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()

# Output the data that was read.
with tf.Session() as sess:
    sess.run(init_op)
    sess.run(local_init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        example, label = sess.run([x_test, y_test])
        print(example)
        print(label)
    except tf.errors.OutOfRangeError:
        print('Done !!!')
    finally:
        # Always stop and join the queue-runner threads, even on error.
        coord.request_stop()
        coord.join(threads=threads)
发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表