tensorflow使用range_input_producer多线程读取数据实例


Posted in Python onJanuary 20, 2020

先放关键代码:

i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

原理解析:

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;

0,1,2,0,1,2

队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。

如果num_epochs不指定,则队列内容是这样子:

0,1,2,0,1,2,0,1,2,0,1,2...

队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。

下面是完整的演示代码。

数据文件test.txt内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

main.py内容:

import tensorflow as tf
import codecs
 
BATCH_SIZE = 6
NUM_EXPOCHES = 5
 
 
def input_producer():
 array = codecs.open("test.txt").readlines()
	array = map(lambda line: line.strip(), array)
 i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
 inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
 return inputs
 
 
class Inputs(object):
 def __init__(self):
  self.inputs = input_producer()
 
 
def main(*args, **kwargs):
 inputs = Inputs()
 init = tf.group(tf.initialize_all_variables(),
     tf.initialize_local_variables())
 sess = tf.Session()
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 sess.run(init)
 try:
  index = 0
  while not coord.should_stop() and index<10:
   datalines = sess.run(inputs.inputs)
   index += 1
   print("step: %d, batch data: %s" % (index, str(datalines)))
 except tf.errors.OutOfRangeError:
  print("Done traing:-------Epoch limit reached")
 except KeyboardInterrupt:
  print("keyboard interrput detected, stop training")
 finally:
  coord.request_stop()
 coord.join(threads)
 sess.close()
 del sess
	
if __name__ == "__main__":
 main()

输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done traing:-------Epoch limit reached

如果range_input_producer去掉参数num_epochs=1,则输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:

InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
 [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。

以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中引用与复制用法实例分析
Jun 04 Python
在Mac OS上搭建Python的开发环境
Dec 24 Python
Python判断列表是否已排序的各种方法及其性能分析
Jun 20 Python
python 线程的暂停, 恢复, 退出详解及实例
Dec 06 Python
matplotlib中legend位置调整解析
Dec 19 Python
python将文本中的空格替换为换行的方法
Mar 19 Python
2019 Python最新面试题及答案16道题
Apr 11 Python
pytorch forward两个参数实例
Jan 17 Python
使用python求解二次规划的问题
Feb 29 Python
详解Python中string模块除去Str还剩下什么
Nov 30 Python
利用Python+OpenCV三步去除水印
May 28 Python
如何用Python搭建gRPC服务
Jun 30 Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
tensorflow-gpu安装的常见问题及解决方案
Jan 20 #Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 #Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
Jan 20 #Python
You might like
关于我转生变成史莱姆这档事:第二季PV上线,萌王2021年回归
2020/05/06 日漫
php中iconv函数使用方法
2008/05/24 PHP
php 远程图片保存到本地的函数类
2008/12/08 PHP
PHP中exec函数和shell_exec函数的区别
2014/08/20 PHP
php curl模拟post请求和提交多维数组的示例代码
2015/11/19 PHP
tp5(thinkPHP5框架)时间查询操作实例分析
2019/05/29 PHP
利用404错误页面实现UrlRewrite的实现代码
2008/08/20 Javascript
载入jQuery库的最佳方法详细说明及实现代码
2012/12/28 Javascript
javascript中setTimeout的问题解决方法
2014/05/08 Javascript
JavaScript获取网页中第一个图片id的方法
2015/04/03 Javascript
jqGrid 学习笔记整理——进阶篇(一 )
2016/04/17 Javascript
基于vue的fullpage.js单页滚动插件
2017/03/20 Javascript
JS实现侧边栏鼠标经过弹出框+缓冲效果
2017/03/29 Javascript
JavaScript闭包和回调详解
2017/08/09 Javascript
关于vue-router的beforeEach无限循环的问题解决
2017/09/09 Javascript
webpack将js打包后的map文件详解
2018/02/22 Javascript
小程序:授权、登录、session_key、unionId的详解
2019/05/15 Javascript
如何基于JS截获动态代码
2019/12/25 Javascript
vue 在methods中调用mounted的实现操作
2020/08/07 Javascript
Python运算符重载用法实例分析
2015/06/01 Python
python 随机生成10位数密码的实现代码
2019/06/27 Python
pymysql 插入数据 转义处理方式
2020/03/02 Python
基于python计算滚动方差(标准差)talib和pd.rolling函数差异详解
2020/06/08 Python
解决HTML5手机端页面缩放的问题
2017/10/27 HTML / CSS
Volcom法国官网:美国冲浪滑板品牌
2017/05/25 全球购物
英国时尚高尔夫服装购物网站:Trendy Golf
2020/01/10 全球购物
英国马莎百货印度官网:Marks & Spencer印度
2020/10/08 全球购物
璀璨的珍珠、密钉和个性化珠宝:Lily & Roo
2021/01/21 全球购物
华三通信H3C面试题
2015/05/15 面试题
学生周末长期请假条
2014/02/15 职场文书
安全在我心中演讲稿
2014/09/01 职场文书
领导干部作风建设自查报告
2014/10/23 职场文书
青年文明号申报材料
2014/12/23 职场文书
乒乓球比赛通知
2015/04/27 职场文书
2015教师年度思想工作总结
2015/04/30 职场文书
基于MySql验证的vsftpd虚拟用户
2021/11/07 MySQL