Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取


Posted in Python onJune 30, 2020

单一数据读取方式:

第一种:slice_input_producer()

# 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中,如[...]
[images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)

第二种:string_input_producer()

# 需要定义文件读取器,然后通过读取器中的 read()方法来获取数据(返回值类型 key,value),再通过 Session.run(value)查看
file_queue = tf.train.string_input_producer(filename, num_epochs=None, shuffle=True)

reader = tf.WholeFileReader()      # 定义文件读取器
key, value = reader.read(file_queue)  # key:文件名;value:文件中的内容

!!!num_epochs=None,不指定迭代次数,这样文件队列中元素个数也不限定(None*数据集大小)。

!!!如果它不是None,则此函数创建本地计数器 epochs,需要使用local_variables_initializer()初始化局部变量

!!!以上两种方法都可以生成文件名队列。

(随机)批量数据读取方式:

batchsize=2# 每次读取的样本数量
tf.train.batch(tensors, batch_size=batchsize)
tf.train.shuffle_batch(tensors, batch_size=batchsize, capacity=batchsize*10, min_after_dequeue=batchsize*5) # capacity > min_after_dequeue

!!!以上所有读取数据的方法,在Session.run()之前必须开启文件队列线程 tf.train.start_queue_runners()

 TFRecord文件的打包与读取

 一、单一数据读取方式

第一种:slice_input_producer()

def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None)

案例1:

import tensorflow as tf

images = ['image1.jpg', 'image2.jpg', 'image3.jpg', 'image4.jpg']
labels = [1, 2, 3, 4]

# [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)

# 当num_epochs=2时,此时文件队列中只有 2*4=8个样本,所有在取第9个样本时会出错
# [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=2, shuffle=True)

data = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)
print(type(data))  # <class 'list'>

with tf.Session() as sess:
  # sess.run(tf.local_variables_initializer())
  sess.run(tf.local_variables_initializer())
  coord = tf.train.Coordinator() # 线程的协调器
  threads = tf.train.start_queue_runners(sess, coord) # 开始在图表中收集队列运行器

  for i in range(10):
    print(sess.run(data))

  coord.request_stop()
  coord.join(threads)

"""

运行结果:

[b'image2.jpg', 2]
[b'image1.jpg', 1]
[b'image3.jpg', 3]
[b'image4.jpg', 4]
[b'image2.jpg', 2]
[b'image1.jpg', 1]
[b'image3.jpg', 3]
[b'image4.jpg', 4]
[b'image2.jpg', 2]
[b'image3.jpg', 3]
"""

!!!slice_input_producer() 中的第一个参数需要放在一个列表中,列表中的每个元素可以是 List 或 Tensor,如 [images,labels],

!!!num_epochs设置

 第二种:string_input_producer()

def string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None, cancel_op=None)

文件读取器

不同类型的文件对应不同的文件读取器,我们称为 reader对象;

该对象的 read 方法自动读取文件,并创建数据队列,输出key/文件名,value/文件内容;

reader = tf.TextLineReader()   ### 一行一行读取,适用于所有文本文件

reader = tf.TFRecordReader()   ### A Reader that outputs the records from a TFRecords file

reader = tf.WholeFileReader()   ### 一次读取整个文件,适用图片

案例2:读取csv文件

import tensorflow as tf

filename = ['data/A.csv', 'data/B.csv', 'data/C.csv']

file_queue = tf.train.string_input_producer(filename, shuffle=True, num_epochs=2)  # 生成文件名队列
reader = tf.WholeFileReader()      # 定义文件读取器(一次读取整个文件)
# reader = tf.TextLineReader()      # 定义文件读取器(一行一行的读)
key, value = reader.read(file_queue)  # key:文件名;value:文件中的内容
print(type(file_queue))

init = [tf.global_variables_initializer(), tf.local_variables_initializer()]
with tf.Session() as sess:
  sess.run(init)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    while not coord.should_stop():
      for i in range(6):
        print(sess.run([key, value]))
      break
  except tf.errors.OutOfRangeError:
    print('read done')
  finally:
    coord.request_stop()
  coord.join(threads)

"""
reader = tf.WholeFileReader()      # 定义文件读取器(一次读取整个文件)
运行结果:
[b'data/C.csv', b'7.jpg,7\n8.jpg,8\n9.jpg,9\n']
[b'data/B.csv', b'4.jpg,4\n5.jpg,5\n6.jpg,6\n']
[b'data/A.csv', b'1.jpg,1\n2.jpg,2\n3.jpg,3\n']
[b'data/A.csv', b'1.jpg,1\n2.jpg,2\n3.jpg,3\n']
[b'data/B.csv', b'4.jpg,4\n5.jpg,5\n6.jpg,6\n']
[b'data/C.csv', b'7.jpg,7\n8.jpg,8\n9.jpg,9\n']
"""
"""
reader = tf.TextLineReader()      # 定义文件读取器(一行一行的读)
运行结果:
[b'data/B.csv:1', b'4.jpg,4']
[b'data/B.csv:2', b'5.jpg,5']
[b'data/B.csv:3', b'6.jpg,6']
[b'data/C.csv:1', b'7.jpg,7']
[b'data/C.csv:2', b'8.jpg,8']
[b'data/C.csv:3', b'9.jpg,9']
"""

案例3:读取图片(每次读取全部图片内容,不是一行一行)

import tensorflow as tf

filename = ['1.jpg', '2.jpg']
filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=1)
reader = tf.WholeFileReader()       # 文件读取器
key, value = reader.read(filename_queue)  # 读取文件 key:文件名;value:图片数据,bytes

with tf.Session() as sess:
  tf.local_variables_initializer().run()
  coord = tf.train.Coordinator()   # 线程的协调器
  threads = tf.train.start_queue_runners(sess, coord)

  for i in range(filename.__len__()):
    image_data = sess.run(value)
    with open('img_%d.jpg' % i, 'wb') as f:
      f.write(image_data)
  coord.request_stop()
  coord.join(threads)

 二、(随机)批量数据读取方式:

功能:shuffle_batch() 和 batch() 这两个API都是从文件队列中批量获取数据,使用方式类似;

案例4:slice_input_producer() 与 batch()

import tensorflow as tf
import numpy as np

images = np.arange(20).reshape([10, 2])
label = np.asarray(range(0, 10))
images = tf.cast(images, tf.float32)# 可以注释掉,不影响运行结果
label = tf.cast(label, tf.int32)

 # 可以注释掉,不影响运行结果

batchsize = 6  # 每次获取元素的数量
input_queue = tf.train.slice_input_producer([images, label], num_epochs=None, shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=batchsize)

# 随机获取 batchsize个元素,其中,capacity:队列容量,这个参数一定要比 min_after_dequeue 大
# image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=batchsize, capacity=64, min_after_dequeue=10)

with tf.Session() as sess:
  coord = tf.train.Coordinator()   # 线程的协调器
  threads = tf.train.start_queue_runners(sess, coord)   # 开始在图表中收集队列运行器
  for cnt in range(2):
    print("第{}次获取数据,每次batch={}...".format(cnt+1, batchsize))
    image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
    print(image_batch_v, label_batch_v, label_batch_v.__len__())

  coord.request_stop()
  coord.join(threads)

"""

运行结果:
第1次获取数据,每次batch=6...
[[ 0.  1.]
 [ 2.  3.]
 [ 4.  5.]
 [ 6.  7.]
 [ 8.  9.]
 [10. 11.]] [0 1 2 3 4 5] 6
第2次获取数据,每次batch=6...
[[12. 13.]
 [14. 15.]
 [16. 17.]
 [18. 19.]
 [ 0.  1.]
 [ 2.  3.]] [6 7 8 9 0 1] 6
"""

 案例5:从本地批量的读取图片 --- string_input_producer() 与 batch()

import tensorflow as tf
 import glob
 import cv2 as cv
 
 def read_imgs(filename, picture_format, input_image_shape, batch_size=):
   """
   从本地批量的读取图片
   :param filename: 图片路径(包括图片的文件名),[]
   :param picture_format: 图片的格式,如 bmp,jpg,png等; string
   :param input_image_shape: 输入图像的大小; (h,w,c)或[]
   :param batch_size: 每次从文件队列中加载图片的数量; int
   :return: batch_size张图片数据, Tensor
   """
   global new_img
   # 创建文件队列
   file_queue = tf.train.string_input_producer(filename, num_epochs=1, shuffle=True)
   # 创建文件读取器
   reader = tf.WholeFileReader()
   # 读取文件队列中的文件
   _, img_bytes = reader.read(file_queue)
   # print(img_bytes)  # Tensor("ReaderReadV2_19:1", shape=(), dtype=string)
   # 对图片进行解码
   if picture_format == ".bmp":
     new_img = tf.image.decode_bmp(img_bytes, channels=1)
   elif picture_format == ".jpg":
     new_img = tf.image.decode_jpeg(img_bytes, channels=3)
   else:
     pass
   # 重新设置图片的大小
   # new_img = tf.image.resize_images(new_img, input_image_shape)
   new_img = tf.reshape(new_img, input_image_shape)
   # 设置图片的数据类型
   new_img = tf.image.convert_image_dtype(new_img, tf.uint)
 
   # return new_img
   return tf.train.batch([new_img], batch_size)
 
 
 def main():
   image_path = glob.glob(r'F:\demo\FaceRecognition\人脸库\ORL\*.bmp')
   image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5)
   print(type(image_batch))
   # image_path = glob.glob(r'.\*.jpg')
   # image_batch = read_imgs(image_path, ".jpg", (313, 500, 3), 1)
 
   sess = tf.Session()
   sess.run(tf.local_variables_initializer())
   tf.train.start_queue_runners(sess=sess)
 
   image_batch = sess.run(image_batch)
   print(type(image_batch))  # <class 'numpy.ndarray'>
 
   for i in range(image_batch.__len__()):
     cv.imshow("win_"+str(i), image_batch[i])
   cv.waitKey()
   cv.destroyAllWindows()
 
 def start():
   image_path = glob.glob(r'F:\demo\FaceRecognition\人脸库\ORL\*.bmp')
   image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5)
   print(type(image_batch))  # <class 'tensorflow.python.framework.ops.Tensor'>
 
 
   with tf.Session() as sess:
     sess.run(tf.local_variables_initializer())
     coord = tf.train.Coordinator()   # 线程的协调器
     threads = tf.train.start_queue_runners(sess, coord)   # 开始在图表中收集队列运行器
     image_batch = sess.run(image_batch)
     print(type(image_batch))  # <class 'numpy.ndarray'>
 
     for i in range(image_batch.__len__()):
       cv.imshow("win_"+str(i), image_batch[i])
     cv.waitKey()
     cv.destroyAllWindows()
 
     # 若使用 with 方式打开 Session,且没加如下行语句,则会出错
     # ERROR:tensorflow:Exception in QueueRunner: Enqueue operation was cancelled;
     # 原因:文件队列线程还处于工作状态(队列中还有图片数据),而加载完batch_size张图片会话就会自动关闭,同时关闭文件队列线程
     coord.request_stop()
     coord.join(threads)
 
 
 if __name__ == "__main__":
   # main()
   start()

案列6:TFRecord文件打包与读取

 TFRecord文件打包案列

def write_TFRecord(filename, data, labels, is_shuffler=True):
   """
   将数据打包成TFRecord格式
   :param filename: 打包后路径名,默认在工程目录下创建该文件;String
   :param data: 需要打包的文件路径名;list
   :param labels: 对应文件的标签;list
   :param is_shuffler:是否随机初始化打包后的数据,默认:True;Bool
   :return: None
   """
   im_data = list(data)
   im_labels = list(labels)
 
   index = [i for i in range(im_data.__len__())]
   if is_shuffler:
     np.random.shuffle(index)
 
   # 创建写入器,然后使用该对象写入样本example
   writer = tf.python_io.TFRecordWriter(filename)
   for i in range(im_data.__len__()):
     im_d = im_data[index[i]]  # im_d:存放着第index[i]张图片的路径信息
     im_l = im_labels[index[i]] # im_l:存放着对应图片的标签信息
 
     # # 获取当前的图片数据 方式一:
     # data = cv2.imread(im_d)
     # # 创建样本
     # ex = tf.train.Example(
     #   features=tf.train.Features(
     #     feature={
     #       "image": tf.train.Feature(
     #         bytes_list=tf.train.BytesList(
     #           value=[data.tobytes()])), # 需要打包成bytes类型
     #       "label": tf.train.Feature(
     #         int64_list=tf.train.Int64List(
     #           value=[im_l])),
     #     }
     #   )
     # )
     # 获取当前的图片数据 方式二:相对于方式一,打包文件占用空间小了一半多
     data = tf.gfile.FastGFile(im_d, "rb").read()
     ex = tf.train.Example(
       features=tf.train.Features(
         feature={
           "image": tf.train.Feature(
             bytes_list=tf.train.BytesList(
               value=[data])), # 此时的data已经是bytes类型
           "label": tf.train.Feature(
             int_list=tf.train.IntList(
               value=[im_l])),
         }
       )
     )
 
     # 写入将序列化之后的样本
     writer.write(ex.SerializeToString())
   # 关闭写入器
   writer.close()

TFReord文件的读取案列

import tensorflow as tf
 import cv2
 
 def read_TFRecord(file_list, batch_size=):
   """
   读取TFRecord文件
   :param file_list: 存放TFRecord的文件名,List
   :param batch_size: 每次读取图片的数量
   :return: 解析后图片及对应的标签
   """
   file_queue = tf.train.string_input_producer(file_list, num_epochs=None, shuffle=True)
   reader = tf.TFRecordReader()
   _, ex = reader.read(file_queue)
   batch = tf.train.shuffle_batch([ex], batch_size, capacity=batch_size * 10, min_after_dequeue=batch_size * 5)
 
   feature = {
     'image': tf.FixedLenFeature([], tf.string),
     'label': tf.FixedLenFeature([], tf.int64)
   }
   example = tf.parse_example(batch, features=feature)
 
   images = tf.decode_raw(example['image'], tf.uint)
   images = tf.reshape(images, [-1, 32, 32, 3])
 
   return images, example['label']
 
 
 
 def main():
   # filelist = ['data/train.tfrecord']
   filelist = ['data/test.tfrecord']
   images, labels = read_TFRecord(filelist, 2)
   with tf.Session() as sess:
     sess.run(tf.local_variables_initializer())
     coord = tf.train.Coordinator()
     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 
     try:
       while not coord.should_stop():
         for i in range():
           image_bth, _ = sess.run([images, labels])
           print(_)
 
           cv2.imshow("image_0", image_bth[0])
           cv2.imshow("image_1", image_bth[1])
         break
     except tf.errors.OutOfRangeError:
       print('read done')
     finally:
       coord.request_stop()
     coord.join(threads)
     cv2.waitKey(0)
     cv2.destroyAllWindows()
 
 if __name__ == "__main__":
   main()

到此这篇关于Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取的文章就介绍到这了,更多相关Tensorflow TFRecord打包与读取内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python入门篇之字典
Oct 17 Python
利用ctypes提高Python的执行速度
Sep 09 Python
python Spyder界面无法打开的解决方法
Apr 27 Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 Python
python实现简单五子棋游戏
Jun 18 Python
python机器学习库scikit-learn:SVR的基本应用
Jun 26 Python
Python整数对象实现原理详解
Jul 01 Python
深入浅析Python 命令行模块 Click
Mar 11 Python
Python生成器generator原理及用法解析
Jul 20 Python
python与c语言的语法有哪些不一样的
Sep 13 Python
彻底解决pip下载pytorch慢的问题方法
Mar 01 Python
pytorch 梯度NAN异常值的解决方案
Jun 05 Python
使用Tensorflow-GPU禁用GPU设置(CPU与GPU速度对比)
Jun 30 #Python
keras的backend 设置 tensorflow,theano操作
Jun 30 #Python
浅谈TensorFlow中读取图像数据的三种方式
Jun 30 #Python
python中 _、__、__xx__()区别及使用场景
Jun 30 #Python
Django实现内容缓存实例方法
Jun 30 #Python
Pytorch 卷积中的 Input Shape用法
Jun 29 #Python
Python闭包装饰器使用方法汇总
Jun 29 #Python
You might like
在JavaScript中调用php程序
2009/03/09 PHP
在Windows系统上安装PHP运行环境文字教程
2010/07/19 PHP
深入PHP运行环境配置的详解
2013/06/04 PHP
PHP实现图片裁剪、添加水印效果代码
2014/10/01 PHP
php分页查询的简单实现代码
2017/03/14 PHP
Laravel如何创建服务器提供者实例代码
2019/04/15 PHP
PHP架构及原理知识点详解
2019/12/22 PHP
禁止刷新,回退的JS
2006/11/25 Javascript
Jquery Ajax 学习实例2 向页面发出请求 返回JSon格式数据
2010/03/15 Javascript
cnblogs TagCloud基于jquery的实现代码
2010/06/11 Javascript
详解AngularJS中module模块的导入导出
2015/12/10 Javascript
JavaScript数据结构与算法之集合(Set)
2016/01/29 Javascript
举例说明JavaScript中的实例对象与原型对象
2016/03/11 Javascript
jQuery easyUI datagrid 增加求和统计行的实现代码
2016/06/01 Javascript
深入理解Javascript中的valueOf与toString
2017/01/04 Javascript
基于JavaScript实现表格滚动分页
2017/11/22 Javascript
js构造函数创建对象是否加new问题
2018/01/22 Javascript
使用D3.js创建物流地图的示例代码
2018/01/27 Javascript
基于Vue中点击组件外关闭组件的实现方法
2018/03/06 Javascript
JavaScript中使用import 和require打包后实现原理分析
2018/03/07 Javascript
Node爬取大批量文件的方法示例
2019/06/28 Javascript
vue项目使用$router.go(-1)返回时刷新原来的界面操作
2020/07/26 Javascript
vue 遮罩层阻止默认滚动事件操作
2020/07/28 Javascript
openLayer4实现动态改变标注图标
2020/08/17 Javascript
解决vue scoped html样式无效的问题
2020/10/24 Javascript
vue编写简单的购物车功能
2021/01/08 Vue.js
python通过floor函数舍弃小数位的方法
2015/03/17 Python
使用Python脚本zabbix自定义key监控oracle连接状态
2019/08/28 Python
python面向对象之类属性和类方法案例分析
2019/12/30 Python
解决pytorch-yolov3 train 报错的问题
2020/02/18 Python
python程序实现BTC(比特币)挖矿的完整代码
2021/01/20 Python
main 函数执行以前,还会执行什么代码
2013/04/17 面试题
历史学专业大学生找工作的自我评价
2013/10/16 职场文书
工作会议方案
2014/05/21 职场文书
2019邀请函格式及范文
2019/05/20 职场文书
Spring 使用注解开发
2022/05/20 Java/Android