用十张图详解TensorFlow数据读取机制(附代码)


Posted in Python onFebruary 06, 2018

在学习TensorFlow的过程中,有很多小伙伴反映读取数据这一块很难理解。确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料。今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下TensorFlow的数据读取机制,文章的最后还会给出实战代码以供参考。

TensorFlow读取机制图解

首先需要思考的一个问题是,什么是数据读取?以图像数据为例,读取数据的过程可以用下图来表示:

用十张图详解TensorFlow数据读取机制(附代码)

假设我们的硬盘中有一个图片数据集0001.jpg,0002.jpg,0003.jpg……我们只需要把它们读取到内存中,然后提供给GPU或是CPU进行计算就可以了。这听起来很容易,但事实远没有那么简单。事实上,我们必须要把数据先读入后才能进行计算,假设读入用时0.1s,计算用时0.9s,那么就意味着每过1s,GPU都会有0.1s无事可做,这就大大降低了运算的效率。

如何解决这个问题?方法就是将读入数据和计算分别放在两个线程中,将数据读入内存的一个队列,如下图所示:

用十张图详解TensorFlow数据读取机制(附代码)

读取线程源源不断地将文件系统中的图片读入到一个内存的队列中,而负责计算的是另一个线程,计算需要数据时,直接从内存队列中取就可以了。这样就可以解决GPU因为IO而空闲的问题!

而在TensorFlow中,为了方便管理,在内存队列前又添加了一层所谓的“文件名队列”。

为什么要添加这一层文件名队列?我们首先得了解机器学习中的一个概念:epoch。对于一个数据集来讲,运行一个epoch就是将这个数据集中的图片全部计算一遍。如一个数据集中有三张图片A.jpg、B.jpg、C.jpg,那么跑一个epoch就是指对A、B、C三张图片都计算了一遍。两个epoch就是指先对A、B、C各计算一遍,然后再全部计算一遍,也就是说每张图片都计算了两遍。

TensorFlow使用文件名队列+内存队列双队列的形式读入文件,可以很好地管理epoch。下面我们用图片的形式来说明这个机制的运行方式。如下图,还是以数据集A.jpg, B.jpg, C.jpg为例,假定我们要跑一个epoch,那么我们就在文件名队列中把A、B、C各放入一次,并在之后标注队列结束。

用十张图详解TensorFlow数据读取机制(附代码)

程序运行后,内存队列首先读入A(此时A从文件名队列中出队):

用十张图详解TensorFlow数据读取机制(附代码)

再依次读入B和C:

用十张图详解TensorFlow数据读取机制(附代码)

用十张图详解TensorFlow数据读取机制(附代码)

此时,如果再尝试读入,系统由于检测到了“结束”,就会自动抛出一个异常(OutOfRange)。外部捕捉到这个异常后就可以结束程序了。这就是TensorFlow中读取数据的基本机制。如果我们要跑2个epoch而不是1个epoch,那只要在文件名队列中将A、B、C依次放入两次再标记结束就可以了。

TensorFlow读取数据机制的对应函数

如何在TensorFlow中创建上述的两个队列呢?

对于文件名队列,我们使用tf.train.string_input_producer函数。这个函数需要传入一个文件名list,系统会自动将它转为一个文件名队列。

此外tf.train.string_input_producer还有两个重要的参数,一个是num_epochs,它就是我们上文中提到的epoch数。另外一个就是shuffle,shuffle是指在一个epoch内文件的顺序是否被打乱。若设置shuffle=False,如下图,每个epoch内,数据还是按照A、B、C的顺序进入文件名队列,这个顺序不会改变:

用十张图详解TensorFlow数据读取机制(附代码)

如果设置shuffle=True,那么在一个epoch内,数据的前后顺序就会被打乱,如下图所示:

用十张图详解TensorFlow数据读取机制(附代码)

在TensorFlow中,内存队列不需要我们自己建立,我们只需要使用reader对象从文件名队列中读取数据就可以了,具体实现可以参考下面的实战代码。

除了tf.train.string_input_producer外,我们还要额外介绍一个函数:tf.train.start_queue_runners。初学者会经常在代码中看到这个函数,但往往很难理解它的用处,在这里,有了上面的铺垫后,我们就可以解释这个函数的作用了。

在我们使用tf.train.string_input_producer创建文件名队列后,整个系统其实还是处于“停滞状态”的,也就是说,我们文件名并没有真正被加入到队列中(如下图所示)。此时如果我们开始计算,因为内存队列中什么也没有,计算单元就会一直等待,导致整个系统被阻塞。

用十张图详解TensorFlow数据读取机制(附代码)

而使用tf.train.start_queue_runners之后,才会启动填充队列的线程,这时系统就不再“停滞”。此后计算单元就可以拿到数据并进行计算,整个程序也就跑起来了,这就是函数tf.train.start_queue_runners的用处。

用十张图详解TensorFlow数据读取机制(附代码)

实战代码

我们用一个具体的例子感受TensorFlow中的数据读取。如图,假设我们在当前文件夹中已经有A.jpg、B.jpg、C.jpg三张图片,我们希望读取这三张图片5个epoch并且把读取的结果重新存到read文件夹中。

用十张图详解TensorFlow数据读取机制(附代码)

对应的代码如下:

# 导入TensorFlow
import TensorFlow as tf 

# 新建一个Session
with tf.Session() as sess:
  # 我们要读三幅图片A.jpg, B.jpg, C.jpg
  filename = ['A.jpg', 'B.jpg', 'C.jpg']
  # string_input_producer会产生一个文件名队列
  filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5)
  # reader从文件名队列中读数据。对应的方法是reader.read
  reader = tf.WholeFileReader()
  key, value = reader.read(filename_queue)
  # tf.train.string_input_producer定义了一个epoch变量,要对它进行初始化
  tf.local_variables_initializer().run()
  # 使用start_queue_runners之后,才会开始填充队列
  threads = tf.train.start_queue_runners(sess=sess)
  i = 0
  while True:
    i += 1
    # 获取图片数据并保存
    image_data = sess.run(value)
    with open('read/test_%d.jpg' % i, 'wb') as f:
      f.write(image_data)

我们这里使用filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5)建立了一个会跑5个epoch的文件名队列。并使用reader读取,reader每次读取一张图片并保存。

运行代码后,我们得到就可以看到read文件夹中的图片,正好是按顺序的5个epoch:

用十张图详解TensorFlow数据读取机制(附代码)

如果我们设置filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5)中的shuffle=True,那么在每个epoch内图像就会被打乱,如图所示:

用十张图详解TensorFlow数据读取机制(附代码)

我们这里只是用三张图片举例,实际应用中一个数据集肯定不止3张图片,不过涉及到的原理都是共通的。

实例:tensorflow读取图片的方法

下面讲解tensorflow如何读取jpg格式的图片,png格式的图片是一样的。有两种情况:

第一种就是把图片看做是一个图片直接读进来,获取图片的原始数据,再进行解码,主要用到的函数就是tf.gfile.FastGFile,tf.image.decode_jpeg

例如:

import tensorflow as tf;  
image_raw_data = tf.gfile.FastGFile('/home/penglu/Desktop/11.jpg').read() 
image = tf.image.decode_jpeg(image_raw_data) #图片解码 
print image.eval(session=tf.Session())

输出:

[[[ 11  63 110]
  [ 14  66 113]
  [ 17  69 116]
  ...,

第二种方式就是把图片看看成一个文件,用队列的方式读取

例如:

import tensorflow as tf;   
path = '/home/penglu/Desktop/11.jpg' 
file_queue = tf.train.string_input_producer([path]) #创建输入队列 
image_reader = tf.WholeFileReader() 
_, image = image_reader.read(file_queue) 
image = tf.image.decode_jpeg(image) 
 
with tf.Session() as sess: 
  coord = tf.train.Coordinator() #协同启动的线程 
  threads = tf.train.start_queue_runners(sess=sess, coord=coord) #启动线程运行队列 
  print sess.run(image) 
  coord.request_stop() #停止所有的线程 
  coord.join(threads)

输出:

[[[ 11  63 110]
  [ 14  66 113]
  [ 17  69 116]
  ...,

总结

这篇文章主要用图解的方式详细介绍了TensorFlow读取数据的机制,最后还给出了对应的实战代码,希望能够给大家学习TensorFlow带来一些实质性的帮助。也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现控制台输入密码的方法
May 29 Python
Python实现图像几何变换
Jul 06 Python
Django中url的反向查询的方法
Mar 14 Python
python读取文件名称生成list的方法
Apr 27 Python
python smtplib发送带附件邮件小程序
May 22 Python
pandas把所有大于0的数设置为1的方法
Jan 26 Python
Python Selenium 之关闭窗口close与quit的方法
Feb 13 Python
django写用户登录判定并跳转制定页面的实例
Aug 21 Python
Python3实现mysql连接和数据框的形成(实例代码)
Jan 17 Python
Python autoescape标签用法解析
Jan 17 Python
解决更改AUTH_USER_MODEL后出现的问题
May 14 Python
python 窃取摄像头照片的实现示例
Jan 08 Python
Python实现matplotlib显示中文的方法详解
Feb 06 #Python
Python实现自动上京东抢手机
Feb 06 #Python
Python获取指定文件夹下的文件名的方法
Feb 06 #Python
TensorFlow如何实现反向传播
Feb 06 #Python
tensorflow TFRecords文件的生成和读取的方法
Feb 06 #Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
You might like
Windows下的PHP5.0安装配制详解
2006/09/05 PHP
php 获取完整url地址
2008/12/20 PHP
openflashchart 2.0 简单案例php版
2012/05/21 PHP
PHP的Yii框架中创建视图和渲染视图的方法详解
2016/03/29 PHP
在laravel中使用Symfony的Crawler组件分析HTML
2017/06/19 PHP
实用的Jquery选项卡TAB示例代码
2013/08/28 Javascript
jquery foreach使用示例
2013/09/12 Javascript
JsRender for index循环索引用法详解
2014/10/31 Javascript
js实现checkbox全选、不选与反选的方法
2015/02/09 Javascript
JQuery判断radio(单选框)是否选中和获取选中值方法总结
2015/04/15 Javascript
readonly和disabled属性的区别
2015/07/26 Javascript
输入法的回车与消息发送快捷键回车的冲突解决方法
2016/08/09 Javascript
JavaScript SHA1加密算法实现详细代码
2016/10/06 Javascript
Ionic+AngularJS实现登录和注册带验证功能
2017/02/09 Javascript
JS移动端/H5同时选择多张图片上传并使用canvas压缩图片
2017/06/20 Javascript
Angularjs添加排序查询功能的实例代码
2017/10/24 Javascript
webpack项目使用eslint建立代码规范实现
2019/05/16 Javascript
详解使用mocha对webpack打包的项目进行"冒烟测试"的大致流程
2020/04/27 Javascript
在Python程序中实现分布式进程的教程
2015/04/28 Python
django认证系统实现自定义权限管理的方法
2018/07/16 Python
python 数字类型和字符串类型的相互转换实例
2018/07/17 Python
对python产生随机的二维数组实例详解
2018/12/13 Python
python使用wxpy轻松实现微信防撤回的方法
2019/02/21 Python
Pytorch中Tensor与各种图像格式的相互转化详解
2019/12/26 Python
在django admin中配置搜索域是一个外键时的处理方法
2020/05/20 Python
pycharm使用技巧之自动调整代码格式总结
2020/11/04 Python
台湾专柜女包:KINAZ
2019/12/26 全球购物
《月亮湾》教学反思
2014/04/14 职场文书
不忘国耻振兴中华演讲稿
2014/05/14 职场文书
妇联主席先进事迹
2014/05/18 职场文书
酒店收银员岗位职责
2015/04/07 职场文书
大学考试作弊检讨书
2015/05/06 职场文书
音乐之声观后感
2015/06/04 职场文书
2015年医院保卫科工作总结
2015/07/23 职场文书
NodeJs内存占用过高的排查实战记录
2021/05/10 NodeJs
Golang原生rpc(rpc服务端源码解读)
2022/04/07 Golang