tensorflow使用range_input_producer多线程读取数据实例


Posted in Python onJanuary 20, 2020

先放关键代码:

i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

原理解析:

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;

0,1,2,0,1,2

队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。

如果num_epochs不指定,则队列内容是这样子:

0,1,2,0,1,2,0,1,2,0,1,2...

队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。

下面是完整的演示代码。

数据文件test.txt内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

main.py内容:

import tensorflow as tf
import codecs
 
BATCH_SIZE = 6
NUM_EXPOCHES = 5
 
 
def input_producer():
 array = codecs.open("test.txt").readlines()
	array = map(lambda line: line.strip(), array)
 i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
 inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
 return inputs
 
 
class Inputs(object):
 def __init__(self):
  self.inputs = input_producer()
 
 
def main(*args, **kwargs):
 inputs = Inputs()
 init = tf.group(tf.initialize_all_variables(),
     tf.initialize_local_variables())
 sess = tf.Session()
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 sess.run(init)
 try:
  index = 0
  while not coord.should_stop() and index<10:
   datalines = sess.run(inputs.inputs)
   index += 1
   print("step: %d, batch data: %s" % (index, str(datalines)))
 except tf.errors.OutOfRangeError:
  print("Done traing:-------Epoch limit reached")
 except KeyboardInterrupt:
  print("keyboard interrput detected, stop training")
 finally:
  coord.request_stop()
 coord.join(threads)
 sess.close()
 del sess
	
if __name__ == "__main__":
 main()

输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done traing:-------Epoch limit reached

如果range_input_producer去掉参数num_epochs=1,则输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:

InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
 [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。

以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取央视节目单的实现代码
Jul 25 Python
Python中使用Queue和Condition进行线程同步的方法
Jan 19 Python
分析Python中设计模式之Decorator装饰器模式的要点
Mar 02 Python
python Django批量导入不重复数据
Mar 25 Python
Python urls.py的三种配置写法实例详解
Apr 28 Python
使用Python爬了4400条淘宝商品数据,竟发现了这些“潜规则”
Mar 23 Python
opencv python统计及绘制直方图的方法
Jan 21 Python
python使用KNN算法识别手写数字
Apr 25 Python
python实现智能语音天气预报
Dec 02 Python
tensorflow实现在函数中用tf.Print输出中间值
Jan 21 Python
vue常用指令代码实例总结
Mar 16 Python
opencv 图像轮廓的实现示例
Jul 08 Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
tensorflow-gpu安装的常见问题及解决方案
Jan 20 #Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 #Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
Jan 20 #Python
You might like
PHP的加密方式及原理
2012/06/14 PHP
php获取本地图片文件并生成xml文件输出具体思路
2013/04/27 PHP
深入mysql_fetch_row()与mysql_fetch_array()的区别详解
2013/06/05 PHP
用PHP解决的一个栈的面试题
2014/07/02 PHP
原生js和jquery中有关透明度设置的相关问题
2014/01/08 Javascript
浅谈jQuery中 wrap() wrapAll() 与 wrapInner()的差异
2014/11/12 Javascript
jQuery插件windowScroll实现单屏滚动特效
2015/07/14 Javascript
JS判断Android、iOS或浏览器的多种方法(四种方法)
2017/06/29 Javascript
vue跨域解决方法
2017/10/15 Javascript
详解Vue组件实现tips的总结
2017/11/01 Javascript
vue+webpack实现异步加载三种用法示例详解
2018/04/24 Javascript
vue.js 图片上传并预览及图片更换功能的实现代码
2018/08/27 Javascript
解决eclipse中没有js代码提示的问题
2018/10/10 Javascript
快速解决layui弹窗按enter键不停弹窗的问题
2019/09/18 Javascript
公众号SVG动画交互实战代码
2020/05/31 Javascript
简析Python的闭包和装饰器
2016/02/26 Python
深入学习Python中的装饰器使用
2016/06/20 Python
安装Python和pygame及相应的环境变量配置(图文教程)
2017/06/04 Python
详解用Python练习画个美队盾牌
2019/03/23 Python
Django应用程序入口WSGIHandler源码解析
2019/08/05 Python
python实现图片插入文字
2019/11/26 Python
通过实例解析python subprocess模块原理及用法
2020/10/10 Python
python中PyQuery库用法分享
2021/01/15 Python
世界上最大的餐具公司:Oneida
2016/12/17 全球购物
美国网上书店:Barnes & Noble
2018/08/15 全球购物
Clarks其乐鞋荷兰官网:Clarks荷兰
2019/07/05 全球购物
母亲追悼会答谢词
2014/01/27 职场文书
迎八一活动主题
2014/01/31 职场文书
初中班主任经验交流材料
2014/05/16 职场文书
幼儿园秋季开学寄语
2014/08/02 职场文书
2014年作风建设心得体会
2014/10/22 职场文书
工作检讨书大全
2015/01/26 职场文书
2015年项目工作总结
2015/04/29 职场文书
2016三严三实专题教育活动心得体会
2016/01/06 职场文书
Vue接口封装的完整步骤记录
2021/05/14 Vue.js
MySQL的全局锁和表级锁的具体使用
2021/08/23 MySQL