首页 > 编程 > Python > 正文

基于Tensorflow批量数据的输入实现方式

2020-02-15 21:23:05
字体:
来源:转载
供稿:网友

基于Tensorflow下的批量数据的输入处理:

1.Tensor TFrecords格式

2.h5py的库的数组方法

在tensorflow的框架下写CNN代码,我在书写过程中,感觉不是框架内容难写, 更多的是我在对图像的预处理和输入这部分花了很多精神。

使用了两种方法:

方法一:

Tensor 以Tfrecords的格式存储数据,如果对数据进行标签,可以同时做到数据打标签。

①创建TFrecords文件

orig_image = '/home/images/train_image/'gen_image = '/home/images/image_train.tfrecords'def create_record():  writer = tf.python_io.TFRecordWriter(gen_image)  class_path = orig_image  for img_name in os.listdir(class_path): #读取每一幅图像    img_path = class_path + img_name     img = Image.open(img_path) #读取图像    #img = img.resize((256, 256)) #设置图片大小, 在这里可以对图像进行处理    img_raw = img.tobytes() #将图片转化为原声bytes     example = tf.train.Example(         features=tf.train.Features(feature={             'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[0])), #打标签             'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))#存储数据             }))    writer.write(example.SerializeToString())  writer.close()

②读取TFrecords文件

def read_and_decode(filename):  #创建文件队列,不限读取的数据  filename_queue = tf.train.string_input_producer([filename])  reader = tf.TFRecordReader()  _, serialized_example = reader.read(filename_queue)  features = tf.parse_single_example(      serialized_example,      features={          'label': tf.FixedLenFeature([], tf.int64),          'img_raw': tf.FixedLenFeature([], tf.string)})  label = features['label']  img = features['img_raw']  img = tf.decode_raw(img, tf.uint8) #tf.float32  img = tf.image.convert_image_dtype(img, dtype=tf.float32)  img = tf.reshape(img, [256, 256, 1])  label = tf.cast(label, tf.int32)  return img, label

③批量读取数据,使用tf.train.batch

min_after_dequeue = 10000capacity = min_after_dequeue + 3 * batch_sizenum_samples= len(os.listdir(orig_image))create_record()img, label = read_and_decode(gen_image)total_batch = int(num_samples/batch_size)image_batch, label_batch = tf.train.batch([img, label], batch_size=batch_size,                      num_threads=32, capacity=capacity) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())with tf.Session() as sess:  sess.run(init_op)  coord = tf.train.Coordinator()  threads = tf.train.start_queue_runners(coord=coord)  for i in range(total_batch):     cur_image_batch, cur_label_batch = sess.run([image_batch, label_batch])  coord.request_stop()  coord.join(threads)

方法二:

使用h5py就是使用数组的格式来存储数据

这个方法比较好,在CNN的过程中,会使用到多个数据类存储,比较好用, 比如一个数据进行了两种以上的变化,并且分类存储,我认为这个方法会比较好用。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表