tensorflow使用range_input_producer多线程读取数据实例


Posted in Python onJanuary 20, 2020

先放关键代码:

i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

原理解析:

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;

0,1,2,0,1,2

队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。

如果num_epochs不指定,则队列内容是这样子:

0,1,2,0,1,2,0,1,2,0,1,2...

队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。

下面是完整的演示代码。

数据文件test.txt内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

main.py内容:

import tensorflow as tf
import codecs
 
BATCH_SIZE = 6
NUM_EXPOCHES = 5
 
 
def input_producer():
 array = codecs.open("test.txt").readlines()
	array = map(lambda line: line.strip(), array)
 i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
 inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
 return inputs
 
 
class Inputs(object):
 def __init__(self):
  self.inputs = input_producer()
 
 
def main(*args, **kwargs):
 inputs = Inputs()
 init = tf.group(tf.initialize_all_variables(),
     tf.initialize_local_variables())
 sess = tf.Session()
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 sess.run(init)
 try:
  index = 0
  while not coord.should_stop() and index<10:
   datalines = sess.run(inputs.inputs)
   index += 1
   print("step: %d, batch data: %s" % (index, str(datalines)))
 except tf.errors.OutOfRangeError:
  print("Done traing:-------Epoch limit reached")
 except KeyboardInterrupt:
  print("keyboard interrput detected, stop training")
 finally:
  coord.request_stop()
 coord.join(threads)
 sess.close()
 del sess
	
if __name__ == "__main__":
 main()

输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done traing:-------Epoch limit reached

如果range_input_producer去掉参数num_epochs=1,则输出:

step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:

InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
 [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。

以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
朴素贝叶斯算法的python实现方法
Nov 18 Python
python 远程统计文件代码分享
May 14 Python
Python脚本暴力破解栅栏密码
Oct 19 Python
Python使用cx_Oracle模块操作Oracle数据库详解
May 07 Python
Python代码实现删除一个list里面重复元素的方法
Apr 02 Python
使用python爬取微博数据打造一颗“心”
Jun 28 Python
Python实用工具FuckIt.py介绍
Jul 02 Python
python字符串格式化方式解析
Oct 19 Python
Python拼接字符串的7种方式详解
Mar 19 Python
python按照list中字典的某key去重的示例代码
Oct 13 Python
python基于opencv批量生成验证码的示例
Apr 28 Python
如何使用Tkinter进行窗口的管理与设置
Jun 30 Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
tensorflow-gpu安装的常见问题及解决方案
Jan 20 #Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 #Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
Jan 20 #Python
You might like
虫族 Zerg 热键控制
2020/03/14 星际争霸
php下判断数组中是否存在相同的值array_unique
2008/03/25 PHP
解析用PHP读写音频文件信息的详解(支持WMA和MP3)
2013/05/10 PHP
百度工程师讲PHP函数的实现原理及性能分析(二)
2015/05/13 PHP
PHP处理postfix邮件内容的方法
2015/06/16 PHP
PHP结合jQuery插件ajaxFileUpload实现异步上传文件实例
2020/08/17 PHP
Yii2框架实现数据库常用操作总结
2017/02/08 PHP
PHP基于rabbitmq操作类的生产者和消费者功能示例
2018/06/16 PHP
php实现的表单验证类完整示例
2019/08/13 PHP
由prototype_1.3.1进入javascript殿堂-类的初探
2006/11/06 Javascript
js prototype 格式化数字 By shawl.qiu
2007/04/02 Javascript
javascript appendChild,innerHTML,join性能比较代码
2009/08/29 Javascript
JS实现网站菜单拖拽移位效果的方法
2015/09/24 Javascript
全面了解JS中的匿名函数
2016/06/29 Javascript
简单实现jQuery多选框功能
2017/01/09 Javascript
jQuery实现的简单排序功能示例【冒泡排序】
2017/01/13 Javascript
从零开始做一个pagination分页组件
2017/03/15 Javascript
浅谈AngularJs 双向绑定原理(数据绑定机制)
2017/12/07 Javascript
在vue中获取微信支付code及code被占用问题的解决方法
2019/04/16 Javascript
vue-router重写push方法,解决相同路径跳转报错问题
2020/08/07 Javascript
Python生成随机数的方法
2014/01/14 Python
python中正则的使用指南
2016/12/04 Python
python中的迭代和可迭代对象代码示例
2017/12/27 Python
Python实现读取Properties配置文件的方法
2018/03/29 Python
python下载库的步骤方法
2019/10/12 Python
解决python便携版无法直接运行py文件的问题
2020/09/01 Python
Space NK美国站:英国高端美妆护肤商城
2017/05/22 全球购物
Kaufmann Mercantile官网:家居装饰、配件、户外及更多
2018/09/28 全球购物
局域网标准
2016/09/10 面试题
门卫班长岗位职责
2013/12/15 职场文书
加强机关作风建设心得体会
2014/10/22 职场文书
2014年行政部工作总结
2014/11/19 职场文书
2015年万圣节活动总结
2015/03/24 职场文书
怎样写工作总结啊!
2019/06/18 职场文书
MySQL之PXC集群搭建的方法步骤
2021/05/25 MySQL
Python实现拼音转换
2021/06/07 Python