tensorflow使用range_input_producer多线程读取数据实例


Posted in Python onJanuary 20, 2020

先放关键代码:

i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

原理解析:

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;

0,1,2,0,1,2

队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。

如果num_epochs不指定,则队列内容是这样子:

0,1,2,0,1,2,0,1,2,0,1,2...

队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。

下面是完整的演示代码。

数据文件test.txt内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

main.py内容:

import tensorflow as tf
import codecs
 
BATCH_SIZE = 6
NUM_EXPOCHES = 5
 
 
def input_producer():
 array = codecs.open("test.txt").readlines()
	array = map(lambda line: line.strip(), array)
 i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
 inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
 return inputs
 
 
class Inputs(object):
 def __init__(self):
  self.inputs = input_producer()
 
 
def main(*args, **kwargs):
 inputs = Inputs()
 init = tf.group(tf.initialize_all_variables(),
     tf.initialize_local_variables())
 sess = tf.Session()
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 sess.run(init)
 try:
  index = 0
  while not coord.should_stop() and index<10:
   datalines = sess.run(inputs.inputs)
   index += 1
   print("step: %d, batch data: %s" % (index, str(datalines)))
 except tf.errors.OutOfRangeError:
  print("Done traing:-------Epoch limit reached")
 except KeyboardInterrupt:
  print("keyboard interrput detected, stop training")
 finally:
  coord.request_stop()
 coord.join(threads)
 sess.close()
 del sess
	
if __name__ == "__main__":
 main()

输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done traing:-------Epoch limit reached

如果range_input_producer去掉参数num_epochs=1,则输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:

InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
 [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。

以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python MD5文件生成码
Jan 12 Python
Python中使用 Selenium 实现网页截图实例
Jul 18 Python
Python max内置函数详细介绍
Nov 17 Python
Python基础学习之常见的内建函数整理
Sep 06 Python
Python的numpy库中将矩阵转换为列表等函数的方法
Apr 04 Python
python3爬取数据至mysql的方法
Jun 26 Python
python爬虫爬取微博评论案例详解
Mar 27 Python
python线程的几种创建方式详解
Aug 29 Python
pycharm进入时每次都是insert模式的解决方式
Feb 05 Python
Python Numpy之linspace用法说明
Apr 17 Python
python实现剪贴板的操作
Jul 01 Python
讲解Python实例练习逆序输出字符串
May 06 Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
tensorflow-gpu安装的常见问题及解决方案
Jan 20 #Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 #Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
Jan 20 #Python
You might like
[EPIC] Larva vs Flash ZvT @ Crossing Field [2017-10-09]
2020/03/17 星际争霸
PHP5中使用PDO连接数据库的方法
2010/08/01 PHP
php查询内存信息操作示例
2019/05/09 PHP
ImageZoom 图片放大镜效果(多功能扩展篇)
2010/04/14 Javascript
对setInterval在火狐和chrome切换标签产生奇怪的效果之探索,与解决方案!
2011/10/29 Javascript
解决用jquery load加载页面到div时,不执行页面js的问题
2014/02/22 Javascript
JavaScript利用正则表达式去除日期中的-
2014/06/09 Javascript
JavaScript中的small()方法使用详解
2015/06/08 Javascript
TypeScript学习之强制类型的转换
2016/12/27 Javascript
深入理解node.js之path模块
2017/05/03 Javascript
详解vue-router 2.0 常用基础知识点之router-link
2017/05/10 Javascript
原生JS获取元素的位置与尺寸实现方法
2017/10/18 Javascript
详解Vue快速零配置的打包工具——parcel
2018/01/16 Javascript
ES6 中可以提升幸福度的小功能
2018/08/06 Javascript
Vue 2.0双向绑定原理的实现方法
2019/10/23 Javascript
vuex actions异步修改状态的实例详解
2019/11/06 Javascript
如何使用 JavaScript 操作浏览器历史记录 API
2020/11/24 Javascript
[46:44]VG vs TNC Supermajor小组赛B组败者组决赛 BO3 第一场 6.2
2018/06/03 DOTA
[01:30:55]VG vs Mineski Supermajor 败者组 BO3 第三场 6.6
2018/06/07 DOTA
Python 自动安装 Rising 杀毒软件
2009/04/24 Python
用python实现批量重命名文件的代码
2012/05/25 Python
Python获取电脑硬件信息及状态的实现方法
2014/08/29 Python
Python的“二维”字典 (two-dimension dictionary)定义与实现方法
2016/04/27 Python
python用reduce和map把字符串转为数字的方法
2016/12/19 Python
python中将函数赋值给变量时需要注意的一些问题
2017/08/18 Python
wxPython多个窗口的基本结构
2019/11/19 Python
浅谈Python 函数式编程
2020/06/20 Python
浅析Python 字符编码与文件处理
2020/09/24 Python
CSS3 对过渡(transition)进行调速以及延时
2020/10/21 HTML / CSS
医学检验专业个人求职信范文
2013/12/04 职场文书
优秀毕业生求职信范文
2014/01/02 职场文书
东京审判观后感
2015/06/01 职场文书
地道战观后感
2015/06/04 职场文书
springboot中一些比较常用的注解总结
2021/06/11 Java/Android
Python实现生活常识解答机器人
2021/06/28 Python
Python pandas求方差和标准差的方法实例
2021/08/04 Python