tensorflow使用range_input_producer多线程读取数据实例


Posted in Python onJanuary 20, 2020

先放关键代码:

i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

原理解析:

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;

0,1,2,0,1,2

队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。

如果num_epochs不指定,则队列内容是这样子:

0,1,2,0,1,2,0,1,2,0,1,2...

队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。

下面是完整的演示代码。

数据文件test.txt内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

main.py内容:

import tensorflow as tf
import codecs
 
BATCH_SIZE = 6
NUM_EXPOCHES = 5
 
 
def input_producer():
 array = codecs.open("test.txt").readlines()
	array = map(lambda line: line.strip(), array)
 i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
 inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
 return inputs
 
 
class Inputs(object):
 def __init__(self):
  self.inputs = input_producer()
 
 
def main(*args, **kwargs):
 inputs = Inputs()
 init = tf.group(tf.initialize_all_variables(),
     tf.initialize_local_variables())
 sess = tf.Session()
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 sess.run(init)
 try:
  index = 0
  while not coord.should_stop() and index<10:
   datalines = sess.run(inputs.inputs)
   index += 1
   print("step: %d, batch data: %s" % (index, str(datalines)))
 except tf.errors.OutOfRangeError:
  print("Done traing:-------Epoch limit reached")
 except KeyboardInterrupt:
  print("keyboard interrput detected, stop training")
 finally:
  coord.request_stop()
 coord.join(threads)
 sess.close()
 del sess
	
if __name__ == "__main__":
 main()

输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done traing:-------Epoch limit reached

如果range_input_producer去掉参数num_epochs=1,则输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:

InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
 [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。

以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 图片验证码代码分享
Jul 04 Python
对python抓取需要登录网站数据的方法详解
May 21 Python
python一行sql太长折成多行并且有多个参数的方法
Jul 19 Python
python获取微信小程序手机号并绑定遇到的坑
Nov 19 Python
对python的unittest架构公共参数token提取方法详解
Dec 17 Python
python找出一个列表中相同元素的多个索引实例
Jun 11 Python
Python MySQL 日期时间格式化作为参数的操作
Mar 02 Python
django model object序列化实例
Mar 13 Python
Django ForeignKey与数据库的FOREIGN KEY约束详解
May 20 Python
Django --Xadmin 判断登录者身份实例
Jul 03 Python
Python实现Excel自动分组合并单元格
Feb 22 Python
PyQt 如何创建自定义QWidget
Mar 24 Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
tensorflow-gpu安装的常见问题及解决方案
Jan 20 #Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 #Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
Jan 20 #Python
You might like
php面向对象的方法重载两种版本比较
2008/09/08 PHP
php 文件缓存函数
2011/10/08 PHP
利用PHP实现短域名互转
2013/07/05 PHP
如何用PHP来实现一个动态Web服务器
2015/07/29 PHP
在openSUSE42.1下编译安装PHP7 的方法
2015/12/24 PHP
Thinkphp开发--集成极光推送
2017/09/15 PHP
效率高的Javscript字符串替换函数的benchmark
2008/08/02 Javascript
ajax页面无刷新 IE下遭遇Ajax缓存导致数据不更新的问题
2012/12/11 Javascript
Seajs的学习笔记
2014/03/04 Javascript
JavaScript中使用Callback控制流程介绍
2015/03/16 Javascript
jQuery实现Meizu魅族官方网站的导航菜单效果
2015/09/14 Javascript
利用jquery禁止外层滚动条的滚动
2017/01/05 Javascript
详解vue2.0监听属性的使用心得及搭配计算属性的使用
2018/07/18 Javascript
小程序scroll-view安卓机隐藏横向滚动条的实现详解
2019/05/16 Javascript
vue 组件中使用 transition 和 transition-group实现过渡动画
2019/07/09 Javascript
关于vue 结合原生js 解决echarts resize问题
2020/07/26 Javascript
[00:36]DOTA2上海特级锦标赛 Archon战队宣传片
2016/03/04 DOTA
python实现判断数组是否包含指定元素的方法
2015/07/15 Python
详解python eval函数的妙用
2017/11/16 Python
Python实现基于TCP UDP协议的IPv4 IPv6模式客户端和服务端功能示例
2018/03/22 Python
Python 实现两个列表里元素对应相乘的方法
2018/11/14 Python
python Selenium实现付费音乐批量下载的实现方法
2019/01/24 Python
对dataframe数据之间求补集的实例详解
2019/01/30 Python
Python高级property属性用法实例分析
2019/11/19 Python
Python完全识别验证码自动登录实例详解
2019/11/24 Python
Python 根据数据模板创建shapefile的实现
2019/11/26 Python
Python3安装模块报错Microsoft Visual C++ 14.0 is required的解决方法
2020/07/28 Python
外企求职信范文分享
2013/12/31 职场文书
实习鉴定评语
2014/01/19 职场文书
高中物理教学反思
2014/02/08 职场文书
高中毕业典礼演讲稿
2014/09/09 职场文书
财政局党的群众路线教育实践活动整改方案
2014/09/21 职场文书
司法局2014法制宣传日活动总结
2014/11/01 职场文书
呼啸山庄读书笔记
2015/06/29 职场文书
浅谈:电影《孔子》观后感(范文)
2019/10/14 职场文书
python识别围棋定位棋盘位置
2021/07/26 Python