Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取


Posted in Python onJune 30, 2020

单一数据读取方式:

第一种:slice_input_producer()

# 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中,如[...]
[images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)

第二种:string_input_producer()

# 需要定义文件读取器,然后通过读取器中的 read()方法来获取数据(返回值类型 key,value),再通过 Session.run(value)查看
file_queue = tf.train.string_input_producer(filename, num_epochs=None, shuffle=True)

reader = tf.WholeFileReader()      # 定义文件读取器
key, value = reader.read(file_queue)  # key:文件名;value:文件中的内容

!!!num_epochs=None,不指定迭代次数,这样文件队列中元素个数也不限定(None*数据集大小)。

!!!如果它不是None,则此函数创建本地计数器 epochs,需要使用local_variables_initializer()初始化局部变量

!!!以上两种方法都可以生成文件名队列。

(随机)批量数据读取方式:

batchsize=2# 每次读取的样本数量
tf.train.batch(tensors, batch_size=batchsize)
tf.train.shuffle_batch(tensors, batch_size=batchsize, capacity=batchsize*10, min_after_dequeue=batchsize*5) # capacity > min_after_dequeue

!!!以上所有读取数据的方法,在Session.run()之前必须开启文件队列线程 tf.train.start_queue_runners()

 TFRecord文件的打包与读取

 一、单一数据读取方式

第一种:slice_input_producer()

def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None)

案例1:

import tensorflow as tf

images = ['image1.jpg', 'image2.jpg', 'image3.jpg', 'image4.jpg']
labels = [1, 2, 3, 4]

# [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)

# 当num_epochs=2时,此时文件队列中只有 2*4=8个样本,所有在取第9个样本时会出错
# [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=2, shuffle=True)

data = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)
print(type(data))  # <class 'list'>

with tf.Session() as sess:
  # sess.run(tf.local_variables_initializer())
  sess.run(tf.local_variables_initializer())
  coord = tf.train.Coordinator() # 线程的协调器
  threads = tf.train.start_queue_runners(sess, coord) # 开始在图表中收集队列运行器

  for i in range(10):
    print(sess.run(data))

  coord.request_stop()
  coord.join(threads)

"""

运行结果:

[b'image2.jpg', 2]
[b'image1.jpg', 1]
[b'image3.jpg', 3]
[b'image4.jpg', 4]
[b'image2.jpg', 2]
[b'image1.jpg', 1]
[b'image3.jpg', 3]
[b'image4.jpg', 4]
[b'image2.jpg', 2]
[b'image3.jpg', 3]
"""

!!!slice_input_producer() 中的第一个参数需要放在一个列表中,列表中的每个元素可以是 List 或 Tensor,如 [images,labels],

!!!num_epochs设置

 第二种:string_input_producer()

def string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None, cancel_op=None)

文件读取器

不同类型的文件对应不同的文件读取器,我们称为 reader对象;

该对象的 read 方法自动读取文件,并创建数据队列,输出key/文件名,value/文件内容;

reader = tf.TextLineReader()   ### 一行一行读取,适用于所有文本文件

reader = tf.TFRecordReader()   ### A Reader that outputs the records from a TFRecords file

reader = tf.WholeFileReader()   ### 一次读取整个文件,适用图片

案例2:读取csv文件

import tensorflow as tf

filename = ['data/A.csv', 'data/B.csv', 'data/C.csv']

file_queue = tf.train.string_input_producer(filename, shuffle=True, num_epochs=2)  # 生成文件名队列
reader = tf.WholeFileReader()      # 定义文件读取器(一次读取整个文件)
# reader = tf.TextLineReader()      # 定义文件读取器(一行一行的读)
key, value = reader.read(file_queue)  # key:文件名;value:文件中的内容
print(type(file_queue))

init = [tf.global_variables_initializer(), tf.local_variables_initializer()]
with tf.Session() as sess:
  sess.run(init)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    while not coord.should_stop():
      for i in range(6):
        print(sess.run([key, value]))
      break
  except tf.errors.OutOfRangeError:
    print('read done')
  finally:
    coord.request_stop()
  coord.join(threads)

"""
reader = tf.WholeFileReader()      # 定义文件读取器(一次读取整个文件)
运行结果:
[b'data/C.csv', b'7.jpg,7\n8.jpg,8\n9.jpg,9\n']
[b'data/B.csv', b'4.jpg,4\n5.jpg,5\n6.jpg,6\n']
[b'data/A.csv', b'1.jpg,1\n2.jpg,2\n3.jpg,3\n']
[b'data/A.csv', b'1.jpg,1\n2.jpg,2\n3.jpg,3\n']
[b'data/B.csv', b'4.jpg,4\n5.jpg,5\n6.jpg,6\n']
[b'data/C.csv', b'7.jpg,7\n8.jpg,8\n9.jpg,9\n']
"""
"""
reader = tf.TextLineReader()      # 定义文件读取器(一行一行的读)
运行结果:
[b'data/B.csv:1', b'4.jpg,4']
[b'data/B.csv:2', b'5.jpg,5']
[b'data/B.csv:3', b'6.jpg,6']
[b'data/C.csv:1', b'7.jpg,7']
[b'data/C.csv:2', b'8.jpg,8']
[b'data/C.csv:3', b'9.jpg,9']
"""

案例3:读取图片(每次读取全部图片内容,不是一行一行)

import tensorflow as tf

filename = ['1.jpg', '2.jpg']
filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=1)
reader = tf.WholeFileReader()       # 文件读取器
key, value = reader.read(filename_queue)  # 读取文件 key:文件名;value:图片数据,bytes

with tf.Session() as sess:
  tf.local_variables_initializer().run()
  coord = tf.train.Coordinator()   # 线程的协调器
  threads = tf.train.start_queue_runners(sess, coord)

  for i in range(filename.__len__()):
    image_data = sess.run(value)
    with open('img_%d.jpg' % i, 'wb') as f:
      f.write(image_data)
  coord.request_stop()
  coord.join(threads)

 二、(随机)批量数据读取方式:

功能:shuffle_batch() 和 batch() 这两个API都是从文件队列中批量获取数据,使用方式类似;

案例4:slice_input_producer() 与 batch()

import tensorflow as tf
import numpy as np

images = np.arange(20).reshape([10, 2])
label = np.asarray(range(0, 10))
images = tf.cast(images, tf.float32)# 可以注释掉,不影响运行结果
label = tf.cast(label, tf.int32)

 # 可以注释掉,不影响运行结果

batchsize = 6  # 每次获取元素的数量
input_queue = tf.train.slice_input_producer([images, label], num_epochs=None, shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=batchsize)

# 随机获取 batchsize个元素,其中,capacity:队列容量,这个参数一定要比 min_after_dequeue 大
# image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=batchsize, capacity=64, min_after_dequeue=10)

with tf.Session() as sess:
  coord = tf.train.Coordinator()   # 线程的协调器
  threads = tf.train.start_queue_runners(sess, coord)   # 开始在图表中收集队列运行器
  for cnt in range(2):
    print("第{}次获取数据,每次batch={}...".format(cnt+1, batchsize))
    image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
    print(image_batch_v, label_batch_v, label_batch_v.__len__())

  coord.request_stop()
  coord.join(threads)

"""

运行结果:
第1次获取数据,每次batch=6...
[[ 0.  1.]
 [ 2.  3.]
 [ 4.  5.]
 [ 6.  7.]
 [ 8.  9.]
 [10. 11.]] [0 1 2 3 4 5] 6
第2次获取数据,每次batch=6...
[[12. 13.]
 [14. 15.]
 [16. 17.]
 [18. 19.]
 [ 0.  1.]
 [ 2.  3.]] [6 7 8 9 0 1] 6
"""

 案例5:从本地批量的读取图片 --- string_input_producer() 与 batch()

import tensorflow as tf
 import glob
 import cv2 as cv
 
 def read_imgs(filename, picture_format, input_image_shape, batch_size=):
   """
   从本地批量的读取图片
   :param filename: 图片路径(包括图片的文件名),[]
   :param picture_format: 图片的格式,如 bmp,jpg,png等; string
   :param input_image_shape: 输入图像的大小; (h,w,c)或[]
   :param batch_size: 每次从文件队列中加载图片的数量; int
   :return: batch_size张图片数据, Tensor
   """
   global new_img
   # 创建文件队列
   file_queue = tf.train.string_input_producer(filename, num_epochs=1, shuffle=True)
   # 创建文件读取器
   reader = tf.WholeFileReader()
   # 读取文件队列中的文件
   _, img_bytes = reader.read(file_queue)
   # print(img_bytes)  # Tensor("ReaderReadV2_19:1", shape=(), dtype=string)
   # 对图片进行解码
   if picture_format == ".bmp":
     new_img = tf.image.decode_bmp(img_bytes, channels=1)
   elif picture_format == ".jpg":
     new_img = tf.image.decode_jpeg(img_bytes, channels=3)
   else:
     pass
   # 重新设置图片的大小
   # new_img = tf.image.resize_images(new_img, input_image_shape)
   new_img = tf.reshape(new_img, input_image_shape)
   # 设置图片的数据类型
   new_img = tf.image.convert_image_dtype(new_img, tf.uint)
 
   # return new_img
   return tf.train.batch([new_img], batch_size)
 
 
 def main():
   image_path = glob.glob(r'F:\demo\FaceRecognition\人脸库\ORL\*.bmp')
   image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5)
   print(type(image_batch))
   # image_path = glob.glob(r'.\*.jpg')
   # image_batch = read_imgs(image_path, ".jpg", (313, 500, 3), 1)
 
   sess = tf.Session()
   sess.run(tf.local_variables_initializer())
   tf.train.start_queue_runners(sess=sess)
 
   image_batch = sess.run(image_batch)
   print(type(image_batch))  # <class 'numpy.ndarray'>
 
   for i in range(image_batch.__len__()):
     cv.imshow("win_"+str(i), image_batch[i])
   cv.waitKey()
   cv.destroyAllWindows()
 
 def start():
   image_path = glob.glob(r'F:\demo\FaceRecognition\人脸库\ORL\*.bmp')
   image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5)
   print(type(image_batch))  # <class 'tensorflow.python.framework.ops.Tensor'>
 
 
   with tf.Session() as sess:
     sess.run(tf.local_variables_initializer())
     coord = tf.train.Coordinator()   # 线程的协调器
     threads = tf.train.start_queue_runners(sess, coord)   # 开始在图表中收集队列运行器
     image_batch = sess.run(image_batch)
     print(type(image_batch))  # <class 'numpy.ndarray'>
 
     for i in range(image_batch.__len__()):
       cv.imshow("win_"+str(i), image_batch[i])
     cv.waitKey()
     cv.destroyAllWindows()
 
     # 若使用 with 方式打开 Session,且没加如下行语句,则会出错
     # ERROR:tensorflow:Exception in QueueRunner: Enqueue operation was cancelled;
     # 原因:文件队列线程还处于工作状态(队列中还有图片数据),而加载完batch_size张图片会话就会自动关闭,同时关闭文件队列线程
     coord.request_stop()
     coord.join(threads)
 
 
 if __name__ == "__main__":
   # main()
   start()

案列6:TFRecord文件打包与读取

 TFRecord文件打包案列

def write_TFRecord(filename, data, labels, is_shuffler=True):
   """
   将数据打包成TFRecord格式
   :param filename: 打包后路径名,默认在工程目录下创建该文件;String
   :param data: 需要打包的文件路径名;list
   :param labels: 对应文件的标签;list
   :param is_shuffler:是否随机初始化打包后的数据,默认:True;Bool
   :return: None
   """
   im_data = list(data)
   im_labels = list(labels)
 
   index = [i for i in range(im_data.__len__())]
   if is_shuffler:
     np.random.shuffle(index)
 
   # 创建写入器,然后使用该对象写入样本example
   writer = tf.python_io.TFRecordWriter(filename)
   for i in range(im_data.__len__()):
     im_d = im_data[index[i]]  # im_d:存放着第index[i]张图片的路径信息
     im_l = im_labels[index[i]] # im_l:存放着对应图片的标签信息
 
     # # 获取当前的图片数据 方式一:
     # data = cv2.imread(im_d)
     # # 创建样本
     # ex = tf.train.Example(
     #   features=tf.train.Features(
     #     feature={
     #       "image": tf.train.Feature(
     #         bytes_list=tf.train.BytesList(
     #           value=[data.tobytes()])), # 需要打包成bytes类型
     #       "label": tf.train.Feature(
     #         int64_list=tf.train.Int64List(
     #           value=[im_l])),
     #     }
     #   )
     # )
     # 获取当前的图片数据 方式二:相对于方式一,打包文件占用空间小了一半多
     data = tf.gfile.FastGFile(im_d, "rb").read()
     ex = tf.train.Example(
       features=tf.train.Features(
         feature={
           "image": tf.train.Feature(
             bytes_list=tf.train.BytesList(
               value=[data])), # 此时的data已经是bytes类型
           "label": tf.train.Feature(
             int_list=tf.train.IntList(
               value=[im_l])),
         }
       )
     )
 
     # 写入将序列化之后的样本
     writer.write(ex.SerializeToString())
   # 关闭写入器
   writer.close()

TFReord文件的读取案列

import tensorflow as tf
 import cv2
 
 def read_TFRecord(file_list, batch_size=):
   """
   读取TFRecord文件
   :param file_list: 存放TFRecord的文件名,List
   :param batch_size: 每次读取图片的数量
   :return: 解析后图片及对应的标签
   """
   file_queue = tf.train.string_input_producer(file_list, num_epochs=None, shuffle=True)
   reader = tf.TFRecordReader()
   _, ex = reader.read(file_queue)
   batch = tf.train.shuffle_batch([ex], batch_size, capacity=batch_size * 10, min_after_dequeue=batch_size * 5)
 
   feature = {
     'image': tf.FixedLenFeature([], tf.string),
     'label': tf.FixedLenFeature([], tf.int64)
   }
   example = tf.parse_example(batch, features=feature)
 
   images = tf.decode_raw(example['image'], tf.uint)
   images = tf.reshape(images, [-1, 32, 32, 3])
 
   return images, example['label']
 
 
 
 def main():
   # filelist = ['data/train.tfrecord']
   filelist = ['data/test.tfrecord']
   images, labels = read_TFRecord(filelist, 2)
   with tf.Session() as sess:
     sess.run(tf.local_variables_initializer())
     coord = tf.train.Coordinator()
     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 
     try:
       while not coord.should_stop():
         for i in range():
           image_bth, _ = sess.run([images, labels])
           print(_)
 
           cv2.imshow("image_0", image_bth[0])
           cv2.imshow("image_1", image_bth[1])
         break
     except tf.errors.OutOfRangeError:
       print('read done')
     finally:
       coord.request_stop()
     coord.join(threads)
     cv2.waitKey(0)
     cv2.destroyAllWindows()
 
 if __name__ == "__main__":
   main()

到此这篇关于Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取的文章就介绍到这了,更多相关Tensorflow TFRecord打包与读取内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python进阶教程之文本文件的读取和写入
Aug 29 Python
python基于右递归解决八皇后问题的方法
May 25 Python
一波神奇的Python语句、函数与方法的使用技巧总结
Dec 08 Python
python 列表降维的实例讲解
Jun 28 Python
在Pycharm terminal中字体大小设置的方法
Jan 16 Python
python如何实现视频转代码视频
Jun 17 Python
使用Python刷淘宝喵币(低阶入门版)
Oct 30 Python
Python中的X[:,0]、X[:,1]、X[:,:,0]、X[:,:,1]、X[:,m:n]和X[:,:,m:n]
Feb 13 Python
python图形开发GUI库wxpython使用方法详解
Feb 14 Python
python闭包、深浅拷贝、垃圾回收、with语句知识点汇总
Mar 11 Python
python 实现汉诺塔游戏
Nov 28 Python
解析python中的jsonpath 提取器
Jan 18 Python
使用Tensorflow-GPU禁用GPU设置(CPU与GPU速度对比)
Jun 30 #Python
keras的backend 设置 tensorflow,theano操作
Jun 30 #Python
浅谈TensorFlow中读取图像数据的三种方式
Jun 30 #Python
python中 _、__、__xx__()区别及使用场景
Jun 30 #Python
Django实现内容缓存实例方法
Jun 30 #Python
Pytorch 卷积中的 Input Shape用法
Jun 29 #Python
Python闭包装饰器使用方法汇总
Jun 29 #Python
You might like
第一节--面向对象编程
2006/11/16 PHP
php语言注释,单行注释和多行注释
2018/01/21 PHP
showModelessDialog()使用详解
2006/09/07 Javascript
JavaScript 类型的包装对象(Typed Wrappers)
2011/10/27 Javascript
javascript 中that的含义示例介绍
2014/05/14 Javascript
SuperSlide2实现图片滚动特效
2014/06/20 Javascript
判断在css加载完毕后执行后续代码示例
2014/09/03 Javascript
js防止页面被iframe调用的方法
2014/10/30 Javascript
利用js实现禁止复制文本信息
2015/06/03 Javascript
JavaScript实现数组降维详解
2017/01/05 Javascript
通过npm引用的vue组件使用详解
2017/03/02 Javascript
node.js-v6新版安装具体步骤(分享)
2017/09/06 Javascript
vue多种弹框的弹出形式的示例代码
2017/09/18 Javascript
JS路由跳转的简单实现代码
2017/09/21 Javascript
Vue项目添加动态浏览器头部title的方法
2018/07/11 Javascript
Vue.js组件使用props传递数据的方法
2019/10/19 Javascript
js实现ajax的用户简单登入功能
2020/06/18 Javascript
[00:12]DAC2018 天才少年转战三号位,他的SOLO是否仍如昔日般强大?
2018/04/06 DOTA
使用Python实现一个简单的项目监控
2015/03/31 Python
收藏整理的一些Python常用方法和技巧
2015/05/18 Python
Python实现给qq邮箱发送邮件的方法
2015/05/28 Python
用python简单实现mysql数据同步到ElasticSearch的教程
2018/05/30 Python
django用户登录和注销的实现方法
2018/07/16 Python
python实现机器学习之元线性回归
2018/09/06 Python
windows下安装Python虚拟环境virtualenvwrapper-win
2019/06/14 Python
程序员的七夕用30行代码让Python化身表白神器
2019/08/07 Python
Python实现ATM系统
2020/02/17 Python
提高EJB性能都有哪些技巧
2012/03/25 面试题
排序都有哪几种方法?请列举。用JAVA实现一个快速排序
2014/02/16 面试题
铁路工务反思材料
2014/02/07 职场文书
优秀的导游求职信范文
2014/04/06 职场文书
县级文明单位申报材料
2014/05/23 职场文书
525心理活动总结
2014/07/04 职场文书
雷锋电影观后感
2015/06/10 职场文书
党组织关系的介绍信模板
2019/06/21 职场文书
创业计划书之都市休闲农庄
2019/12/28 职场文书