tensorflow使用range_input_producer多线程读取数据实例


Posted in Python onJanuary 20, 2020

先放关键代码:

i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

原理解析:

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;

0,1,2,0,1,2

队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。

如果num_epochs不指定,则队列内容是这样子:

0,1,2,0,1,2,0,1,2,0,1,2...

队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。

下面是完整的演示代码。

数据文件test.txt内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

main.py内容:

import tensorflow as tf
import codecs
 
BATCH_SIZE = 6
NUM_EXPOCHES = 5
 
 
def input_producer():
 array = codecs.open("test.txt").readlines()
	array = map(lambda line: line.strip(), array)
 i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
 inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
 return inputs
 
 
class Inputs(object):
 def __init__(self):
  self.inputs = input_producer()
 
 
def main(*args, **kwargs):
 inputs = Inputs()
 init = tf.group(tf.initialize_all_variables(),
     tf.initialize_local_variables())
 sess = tf.Session()
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 sess.run(init)
 try:
  index = 0
  while not coord.should_stop() and index<10:
   datalines = sess.run(inputs.inputs)
   index += 1
   print("step: %d, batch data: %s" % (index, str(datalines)))
 except tf.errors.OutOfRangeError:
  print("Done traing:-------Epoch limit reached")
 except KeyboardInterrupt:
  print("keyboard interrput detected, stop training")
 finally:
  coord.request_stop()
 coord.join(threads)
 sess.close()
 del sess
	
if __name__ == "__main__":
 main()

输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done traing:-------Epoch limit reached

如果range_input_producer去掉参数num_epochs=1,则输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:

InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
 [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。

以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python导入txt数据到mysql的方法
Apr 08 Python
Python的字典和列表的使用中一些需要注意的地方
Apr 24 Python
python抓取文件夹的所有文件
Feb 27 Python
python自动重试第三方包retrying模块的方法
Apr 24 Python
pandas.loc 选取指定列进行操作的实例
May 18 Python
Jupyter中直接显示Matplotlib的图形方法
May 24 Python
解决python测试opencv时imread导致的错误问题
Jan 26 Python
Python中新式类与经典类的区别详析
Jul 10 Python
pytorch 模型可视化的例子
Aug 17 Python
face++与python实现人脸识别签到(考勤)功能
Aug 28 Python
基于python爬取有道翻译过程图解
Mar 31 Python
python导入库的具体方法
Jun 18 Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
tensorflow-gpu安装的常见问题及解决方案
Jan 20 #Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 #Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
Jan 20 #Python
You might like
php UTF8 文件的签名问题
2009/10/30 PHP
30 个很棒的PHP开源CMS内容管理系统小结
2011/10/14 PHP
Yii中使用PHPExcel导出Excel的方法
2014/12/26 PHP
PHP正则验证字符串是否为数字的两种方法并附常用正则
2019/02/27 PHP
用js统计用户下载网页所需时间的脚本
2008/10/15 Javascript
Js 获取当前日期时间及其它操作实现代码
2021/03/04 Javascript
JS实现根据当前文字选择返回被选中的文字
2014/05/21 Javascript
express的中间件bodyParser详解
2014/12/04 Javascript
JS绘制生成花瓣效果的方法
2015/08/05 Javascript
JavaScript操作选择对象的简单实例
2016/05/16 Javascript
javascript中replace使用方法总结
2017/03/01 Javascript
js实现一个猜数字游戏
2017/03/31 Javascript
Vue中父组件向子组件通信的方法
2017/07/11 Javascript
JavaScript文件的同步和异步加载的实现代码
2017/08/19 Javascript
vue+Java后端进行调试时解决跨域问题的方式
2017/10/19 Javascript
分享ES6的7个实用技巧
2018/01/18 Javascript
nodejs npm错误Error:UNKNOWN:unknown error,mkdir 'D:\Develop\nodejs\node_global'at Error
2019/03/02 NodeJs
微信小程序传值以及获取值方法的详解
2019/04/29 Javascript
D3.js的基础部分之数组的处理数组的排序和求值(v3版本)
2019/05/09 Javascript
解决vue项目F5刷新mounted里的函数不执行问题
2019/11/05 Javascript
JavaScript中Object、map、weakmap的区别分析
2020/12/15 Javascript
vue二选一tab栏切换新做法实现
2021/01/19 Vue.js
Python的Django框架中的select_related函数对QuerySet 查询的优化
2015/04/01 Python
Python用imghdr模块识别图片格式实例解析
2018/01/11 Python
利用Pandas 创建空的DataFrame方法
2018/04/08 Python
Python迭代器与生成器基本用法分析
2018/07/26 Python
pytorch中的自定义数据处理详解
2020/01/06 Python
python next()和iter()函数原理解析
2020/02/07 Python
pycharm设置python文件模板信息过程图解
2020/03/10 Python
python自定义函数def的应用详解
2020/06/03 Python
Python爬虫之Selenium设置元素等待的方法
2020/12/04 Python
西班牙高科技产品购物网站:MejorDeseo
2019/09/08 全球购物
Marc O’Polo俄罗斯官方在线商店:德国高端时尚品牌
2019/12/26 全球购物
网络维护中文求职信
2014/01/03 职场文书
优秀护士先进事迹
2014/05/08 职场文书
2014年旅游局法制宣传日活动总结
2014/11/01 职场文书