利用Tensorflow的队列多线程读取数据方式


Posted in Python onFebruary 05, 2020

在tensorflow中,有三种方式输入数据

1. 利用feed_dict送入numpy数组

2. 利用队列从文件中直接读取数据

3. 预加载数据

其中第一种方式很常用,在tensorflow的MNIST训练源码中可以看到,通过feed_dict={},可以将任意数据送入tensor中。

第二种方式相比于第一种,速度更快,可以利用多线程的优势把数据送入队列,再以batch的方式出队,并且在这个过程中可以很方便地对图像进行随机裁剪、翻转、改变对比度等预处理,同时可以选择是否对数据随机打乱,可以说是非常方便。该部分的源码在tensorflow官方的CIFAR-10训练源码中可以看到,但是对于刚学习tensorflow的人来说,比较难以理解,本篇博客就当成我调试完成后写的一篇总结,以防自己再忘记具体细节。

读取CIFAR-10数据集

按照第一种方式的话,CIFAR-10的读取只需要写一段非常简单的代码即可将测试集与训练集中的图像分别读取:

path = 'E:\Dataset\cifar-10\cifar-10-batches-py'
# extract train examples
num_train_examples = 50000
x_train = np.empty((num_train_examples, 32, 32, 3), dtype='uint8')
y_train = np.empty((num_train_examples), dtype='uint8')
for i in range(1, 6): 
 fpath = os.path.join(path, 'data_batch_' + str(i)) 
 (x_train[(i - 1) * 10000: i * 10000, :, :, :], y_train[(i - 1) * 10000: i * 10000])   = load_and_decode(fpath)
# extract test examples
fpath = os.path.join(path, 'test_batch')
x_test, y_test = load_and_decode(fpath)
return x_train, y_train, x_test, np.array(y_test)

其中load_and_decode函数只需要按照CIFAR-10官网给出的方式decode就行,最终返回的x_train是一个[50000, 32, 32, 3]的ndarray,但对于ndarray来说,进行预处理就要麻烦很多,为了取mini-SGD的batch,还自己写了一个类,通过调用train_set.next_batch()函数来取,总而言之就是什么都要自己动手,效率确实不高

但对于第二种方式,读取起来就要麻烦很多,但使用起来,又快又方便

首先,把CIFAR-10的测试集文件读取出来,生成文件名列表

path = 'E:\Dataset\cifar-10\cifar-10-batches-py'
filenames = [os.path.join(path, 'data_batch_%d' % i) for i in range(1, 6)]

有了列表以后,利用tf.train.string_input_producer函数生成一个读取队列

filename_queue = tf.train.string_input_producer(filenames)

接下来,我们调用read_cifar10函数,得到一幅一幅的图像,该函数的代码如下:

def read_cifar10(filename_queue):
 label_bytes = 1
 IMAGE_SIZE = 32
 CHANNELS = 3
 image_bytes = IMAGE_SIZE*IMAGE_SIZE*3
 record_bytes = label_bytes+image_bytes

 # define a reader
 reader = tf.FixedLengthRecordReader(record_bytes)
 key, value = reader.read(filename_queue)
 record_bytes = tf.decode_raw(value, tf.uint8)

 label = tf.strided_slice(record_bytes, [0], [label_bytes])
 depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes],  
            [label_bytes + image_bytes]),
        [CHANNELS, IMAGE_SIZE, IMAGE_SIZE])
 image = tf.transpose(depth_major, [1, 2, 0])
 return image, label

第9行,定义一个reader,来读取固定长度的数据,这个固定长度是由CIFAR-10数据集图片的存储格式决定的,1byte的标签加上32 *32 *3长度的图像,3代表RGB三通道,由于图片的是按[channel, height, width]的格式存储的,为了变为常用的[height, width, channel]维度,需要在17行reshape一次图像,最终我们提取出了一副完整的图像与对应的标签

对图像进行预处理

我们取出的image与label均为tensor格式,因此预处理将变得非常简单

if not distortion:
  IMAGE_SIZE = 32
 else:
  IMAGE_SIZE = 24
  # 随机裁剪为24*24大小
  distorted_image = tf.random_crop(tf.cast(image, tf.float32), [IMAGE_SIZE, IMAGE_SIZE, 3])
  # 随机水平翻转
  distorted_image = tf.image.random_flip_left_right(distorted_image)
  # 随机调整亮度
  distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
  # 随机调整对比度
  distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)
  # 对图像进行白化操作,即像素值转为零均值单位方差
  float_image = tf.image.per_image_standardization(distorted_image)

distortion是定义的一个输入布尔型变量,默认为True,表示是否对图像进行处理

填充队列与随机打乱

调用tf.train.shuffle_batch或tf.train.batch函数,以tf.train.shuffle_batch为例,函数的定义如下:

def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
     num_threads=1, seed=None, enqueue_many=False, shapes=None,
     allow_smaller_final_batch=False, shared_name=None, name=None):

tensors表示输入的张量(tensor),batch_size表示要输出的batch的大小,capacity表示队列的容量,即大小,min_after_dequeue表示出队操作后队列中的最小元素数量,这个值是要小于队列的capacity的,通过调整min_after_dequeue与capacity两个变量,可以改变数据被随机打乱的程度,num_threads表示使用的线程数,只要取大于1的数,队列的效率就会高很多。

通常情况下,我们只需要输入以上几个变量即可,在CIFAR-10_input.py中,谷歌给出的代码是这样写的:

if shuffle:
 images, label_batch = tf.train.shuffle_batch([image, label], batch_size,         min_queue_examples+3*batch_size,
       min_queue_examples, num_preprocess_threads)
else:
 images, label_batch = tf.train.batch([image, label], batch_size,
           num_preprocess_threads, 
           min_queue_examples + 3 * batch_size)

min_queue_examples由以下方式得到:

min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN 
       *min_fraction_of_examples_in_queue)

当然,这些值均可以自己随意设置,

最终得到的images,labels(label_batch),即为shape=[128, 32, 32, 3]的tensor,其中128为默认batch_size。

激活队列与处理异常

得到了images和labels两个tensor后,我们便可以把这两个tensor送入graph中进行运算了

# input tensor
img_batch, label_batch = cifar10_input.tesnsor_shuffle_input(batch_size)

# build graph that computes the logits predictions from the inference model
logits, predicts = train.inference(img_batch, keep_prob)

# calculate loss
loss = train.loss(logits, label_batch)

定义sess=tf.Session()后,运行sess.run(),然而你会发现并没有输出,程序直接挂起了,仿佛死掉了一样

原因是这样的,虽然我们在数据流图中加入了队列,但只有调用tf.train.start_queue_runners()函数后,数据才会动起来,被负责输入管道的线程填入队列,否则队列将会挂起。

OK,我们调用函数,让队列运行起来

with tf.Session(config=run_config) as sess:
 sess.run(init_op) # intialization
 queue_runner = tf.train.start_queue_runners(sess)
 for i in range(10):
  b1, b2 = sess.run([img_batch, label_batch])
  print(b1.shape)

在这里为了测试,我们取10次输出,看看输出的batch1的维度是否正确

利用Tensorflow的队列多线程读取数据方式

10个batch的维度均为正确的,但是tensorflow却报了错,错误的文字内容如下:

2017-12-19 16:40:56.429687: W C:\tf_jenkins\home\workspace\rel-win\M\windows-gpu\PY\36\tensorflow\core\kernels\queue_base.cc:295] _ 0 _ input_producer: Skipping cancelled enqueue attempt with queue not closed

简单地看一下,大致意思是说我们的队列里还有数据,但是程序结束了,抛出了异常,因此,我们还需要定义一个Coordinator,也就是协调器来处理异常

Coordinator有3个主要方法:

1. tf.train.Coordinator.should_stop() 如果线程应该停止,返回True

2. tf.train.Coordinator.request_stop() 请求停止线程

3. tf.train.Coordinator.join() 等待直到指定线程停止

首先,定义协调器

coord = tf.train.Coordinator()

将协调器应用于QueueRunner

queue_runner = tf.train.start_queue_runners(sess, coord=coord)

结束数据的训练或测试后,关闭线程

coord.request_stop()
coord.join(queue_runner)

最终的sess代码段如下:

coord = tf.train.Coordinator()
with tf.Session(config=run_config) as sess:
 sess.run(init_op)
 queue_runner = tf.train.start_queue_runners(sess, coord=coord)
 for i in range(10):
  b1, b2 = sess.run([img_batch, label_batch])
  print(b1.shape)
 coord.request_stop()
 coord.join(queue_runner)

得到的输出结果为:

利用Tensorflow的队列多线程读取数据方式

完美解决,利用img_batch与label_batch,把tensor送入graph中,就可以享受tensorflow带来的训练乐趣了

以上这篇利用Tensorflow的队列多线程读取数据方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 算法 排序实现快速排序
Jun 05 Python
解决python写的windows服务不能启动的问题
Apr 15 Python
python实现应用程序在右键菜单中添加打开方式功能
Jan 09 Python
Pandas 数据框增、删、改、查、去重、抽样基本操作方法
Apr 12 Python
浅谈numpy数组中冒号和负号的含义
Apr 18 Python
解决Mac安装scrapy失败的问题
Jun 13 Python
python 实现敏感词过滤的方法
Jan 21 Python
python3+selenium实现126邮箱登陆并发送邮件功能
Jan 23 Python
Python正则表达式匹配日期与时间的方法
Jul 07 Python
Python3 requests文件下载 期间显示文件信息和下载进度代码实例
Aug 16 Python
python实现用类读取文件数据并计算矩形面积
Jan 18 Python
Python数据可视化处理库PyEcharts柱状图,饼图,线性图,词云图常用实例详解
Feb 10 Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
You might like
PHP与SQL注入攻击[一]
2007/04/17 PHP
ThinkPHP与PHPExcel冲突解决方法
2011/08/08 PHP
php按单词截取字符串的方法
2015/04/07 PHP
PHP实现的各类hash算法长度及性能测试实例
2017/08/27 PHP
PHP实现基于栈的后缀表达式求值功能
2017/11/10 PHP
PHP实现随机数字、字母的验证码功能
2018/08/01 PHP
浅析JavaScript中的常用算法与函数
2013/11/21 Javascript
浅谈document.write()输出样式
2015/05/07 Javascript
JavaScript中Date对象的常用方法示例
2015/10/24 Javascript
JavaScript面向对象之私有静态变量实例分析
2016/01/14 Javascript
深入理解Javascript中的观察者模式
2017/02/20 Javascript
vue.js组件之间传递数据的方法
2017/07/10 Javascript
AngularJS动态添加数据并删除的实例
2018/02/27 Javascript
js array数组对象操作方法汇总
2019/03/18 Javascript
Vue实现页面添加水印功能
2019/11/09 Javascript
JavaScript或jQuery 获取option value值方法解析
2020/05/12 jQuery
JS倒计时两种实现方式代码实例
2020/07/27 Javascript
关于小程序优化的一些建议(小结)
2020/12/10 Javascript
[54:28]EG vs OG 2019国际邀请赛小组赛 BO2 第一场 8.16
2019/08/18 DOTA
Python sys.path详细介绍
2013/10/17 Python
Python 实现数据库(SQL)更新脚本的生成方法
2017/07/09 Python
Linux下python制作名片示例
2018/07/20 Python
centos6.8安装python3.7无法import _ssl的解决方法
2018/09/17 Python
python实现一组典型数据格式转换
2018/12/15 Python
浅析PyTorch中nn.Linear的使用
2019/08/18 Python
pytorch 常用函数 max ,eq说明
2020/06/28 Python
使用npy转image图像并保存的实例
2020/07/01 Python
Python-split()函数实例用法讲解
2020/12/18 Python
物流专业大学生求职信范文
2013/10/28 职场文书
《跟踪台风的卫星》教学反思
2014/04/10 职场文书
医院合作协议书
2014/08/19 职场文书
实习科室评语
2015/01/04 职场文书
会计出纳岗位职责
2015/03/31 职场文书
员工福利申请报告
2015/05/15 职场文书
2015年大学组织委员个人工作总结
2015/10/23 职场文书
Python如何配置环境变量详解
2021/05/18 Python