tensorflow使用range_input_producer多线程读取数据实例


Posted in Python onJanuary 20, 2020

先放关键代码:

i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

原理解析:

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;

0,1,2,0,1,2

队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。

如果num_epochs不指定,则队列内容是这样子:

0,1,2,0,1,2,0,1,2,0,1,2...

队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。

下面是完整的演示代码。

数据文件test.txt内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

main.py内容:

import tensorflow as tf
import codecs
 
BATCH_SIZE = 6
NUM_EXPOCHES = 5
 
 
def input_producer():
 array = codecs.open("test.txt").readlines()
	array = map(lambda line: line.strip(), array)
 i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
 inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
 return inputs
 
 
class Inputs(object):
 def __init__(self):
  self.inputs = input_producer()
 
 
def main(*args, **kwargs):
 inputs = Inputs()
 init = tf.group(tf.initialize_all_variables(),
     tf.initialize_local_variables())
 sess = tf.Session()
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 sess.run(init)
 try:
  index = 0
  while not coord.should_stop() and index<10:
   datalines = sess.run(inputs.inputs)
   index += 1
   print("step: %d, batch data: %s" % (index, str(datalines)))
 except tf.errors.OutOfRangeError:
  print("Done traing:-------Epoch limit reached")
 except KeyboardInterrupt:
  print("keyboard interrput detected, stop training")
 finally:
  coord.request_stop()
 coord.join(threads)
 sess.close()
 del sess
	
if __name__ == "__main__":
 main()

输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done traing:-------Epoch limit reached

如果range_input_producer去掉参数num_epochs=1,则输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:

InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
 [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。

以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过socket查询whois的方法
Jul 18 Python
Linux下通过python访问MySQL、Oracle、SQL Server数据库的方法
Apr 23 Python
Python彩色化Linux的命令行终端界面的代码实例分享
Jul 02 Python
Python自动生产表情包
Mar 17 Python
Python中模块pymysql查询结果后如何获取字段列表
Jun 05 Python
python自定义异常实例详解
Jul 11 Python
Python网络爬虫神器PyQuery的基本使用教程
Feb 03 Python
Python 利用切片从列表中取出一部分使用的方法
Feb 01 Python
Python分割训练集和测试集的方法示例
Sep 19 Python
浅谈JupyterNotebook导出pdf解决中文的问题
Apr 22 Python
python用opencv 图像傅里叶变换
Jan 04 Python
Python实现我的世界小游戏源代码
Mar 02 Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
tensorflow-gpu安装的常见问题及解决方案
Jan 20 #Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 #Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
Jan 20 #Python
You might like
PHP 数组实例说明
2008/08/18 PHP
重新认识php array_merge函数
2014/08/31 PHP
ajax调用返回php接口返回json数据的方法(必看篇)
2017/05/05 PHP
laravel框架语言包拓展实现方法分析
2019/11/22 PHP
JavaScript confirm选择判断
2008/10/18 Javascript
图片上传即时显示缩略图的js代码
2009/05/27 Javascript
JS 日期验证正则附asp日期格式化函数
2009/09/11 Javascript
Firefox中beforeunload事件的实现缺陷浅析
2012/05/03 Javascript
IE8下Jquery获取select选中的值post到后台报错问题
2014/07/02 Javascript
浅谈JavaScript Math和Number对象
2015/01/26 Javascript
详谈javascript中的cookie
2015/06/03 Javascript
javascript下使用Promise封装FileReader
2016/02/19 Javascript
JS实现仿PS的调色板效果完整实例
2016/12/21 Javascript
AngularJS constant和value区别详解
2017/02/28 Javascript
VUE简单的定时器实时刷新的实现方法
2019/01/20 Javascript
详解vue 自定义marquee无缝滚动组件
2019/04/09 Javascript
vuex 中插件的编写案例解析
2019/06/10 Javascript
[51:32]Optic vs Serenity 2018国际邀请赛淘汰赛BO3 第一场 8.22
2018/08/23 DOTA
Python压缩和解压缩zip文件
2015/02/14 Python
Python多线程编程(三):threading.Thread类的重要函数和方法
2015/04/05 Python
Python对象转JSON字符串的方法
2016/04/27 Python
OpenCV-Python实现轮廓检测实例分析
2018/01/05 Python
在NumPy中创建空数组/矩阵的方法
2018/06/15 Python
python3的输入方式及多组输入方法
2018/10/17 Python
python运行时强制刷新缓冲区的方法
2019/01/14 Python
使用Pyinstaller转换.py文件为.exe可执行程序过程详解
2019/08/06 Python
python虚拟环境完美部署教程
2019/08/06 Python
详解Anconda环境下载python包的教程(图形界面+命令行+pycharm安装)
2019/11/11 Python
python传到前端的数据,双引号被转义的问题
2020/04/03 Python
html5配合css3实现带提示文字的输入框(摆脱js)
2013/03/08 HTML / CSS
美国在线打印网站:Overnight Prints
2018/10/11 全球购物
餐饮收银员岗位职责
2014/02/07 职场文书
社区志愿者活动总结
2014/06/26 职场文书
2014标准社保办理委托书
2014/10/06 职场文书
用人单位的规章制度,怎样制定才是有效的?
2019/07/09 职场文书
Java异常处理try catch的基本用法
2021/12/06 Java/Android