对tensorflow中cifar-10文档的Read操作详解


Posted in Python onFebruary 10, 2020

前言

在tensorflow的官方文档中得卷积神经网络一章,有一个使用cifar-10图片数据集的实验,搭建卷积神经网络倒不难,但是那个cifar10_input文件着实让我费了一番心思。配合着官方文档也算看的七七八八,但是中间还是有一些不太明白,不明白的mark一下,这次记下一些已经明白的。

研究

cifar10_input.py文件的read操作,主要的就是下面的代码:

if not eval_data:
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
         for i in xrange(1, 6)]
  num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
 else:
  filenames = [os.path.join(data_dir, 'test_batch.bin')]
  num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
...
filename_queue = tf.train.string_input_producer(filenames)

...

label_bytes = 1 # 2 for CIFAR-100
 result.height = 32
 result.width = 32
 result.depth = 3
 image_bytes = result.height * result.width * result.depth
 # Every record consists of a label followed by the image, with a
 # fixed number of bytes for each.
 record_bytes = label_bytes + image_bytes

 # Read a record, getting filenames from the filename_queue. No
 # header or footer in the CIFAR-10 format, so we leave header_bytes
 # and footer_bytes at their default of 0.
 reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
 result.key, value = reader.read(filename_queue)

 ...

 if shuffle:
  images, label_batch = tf.train.shuffle_batch(
    [image, label],
    batch_size=batch_size,
    num_threads=num_preprocess_threads,
    capacity=min_queue_examples + 3 * batch_size,
    min_after_dequeue=min_queue_examples)
 else:
  images, label_batch = tf.train.batch(
    [image, label],
    batch_size=batch_size,
    num_threads=num_preprocess_threads,
    capacity=min_queue_examples + 3 * batch_size)

开始并不明白这段代码是用来干什么的,越看越糊涂,因为之前使用tensorflow最多也就是使用哪个tf.placeholder()这个操作,并没有使用tensorflow自带的读写方法来读写,所以上面的代码看的很费劲儿。不过我在官方文档的How-To这个document中看到了这个东西:

Batching

def read_my_file_format(filename_queue):
 reader = tf.SomeReader()
 key, record_string = reader.read(filename_queue)
 example, label = tf.some_decoder(record_string)
 processed_example = some_processing(example)
 return processed_example, label

def input_pipeline(filenames, batch_size, num_epochs=None):
 filename_queue = tf.train.string_input_producer(
   filenames, num_epochs=num_epochs, shuffle=True)
 example, label = read_my_file_format(filename_queue)
 # min_after_dequeue defines how big a buffer we will randomly sample
 #  from -- bigger means better shuffling but slower start up and more
 #  memory used.
 # capacity must be larger than min_after_dequeue and the amount larger
 #  determines the maximum we will prefetch. Recommendation:
 #  min_after_dequeue + (num_threads + a small safety margin) * batch_size
 min_after_dequeue = 10000
 capacity = min_after_dequeue + 3 * batch_size
 example_batch, label_batch = tf.train.shuffle_batch(
   [example, label], batch_size=batch_size, capacity=capacity,
   min_after_dequeue=min_after_dequeue)
 return example_batch, label_batch

感觉豁然开朗,再研究一下其官方文档API就能大约明白期间意思。最有代表性的图示官方文档中也给出来了,虽然官方文档给的解释并不多。

对tensorflow中cifar-10文档的Read操作详解

API我就不一一解释了,我们下面通过实验来明白。

实验

首先在tensorflow路径下创建两个文件,分别命名为test.txt以及test2.txt,其内容分别是:

test.txt:

test line1
test line2
test line3
test line4
test line5
test line6

test2.txt:

test2 line1
test2 line2
test2 line3
test2 line4
test2 line5
test2 line6

然后再命令行里依次键入下面的命令:

import tensorflow as tf
filenames=['test.txt','test2.txt']
#创建如上图所示的filename_queue
filename_queue=tf.train.string_input_producer(filenames)
#选取的是每次读取一行的TextLineReader
reader=tf.TextLineReader()
init=tf.initialize_all_variables()
#读取文件,也就是创建上图中的Reader
key,value=reader.read(filename_queue)
#读取batch文件,batch_size设置成1,为了方便看
bs=tf.train.batch([value],batch_size=1,num_threads=1,capacity=2)
sess=tf.Session() 
#非常关键,这个是连通各个queue图的关键          
tf.train.start_queue_runners(sess=sess)
#计算有reader的输出
b=reader.num_records_produced()

然后我们执行:

>>> sess.run(bs)
array(['test line1'], dtype=object)
>>> sess.run(b)
4
>>> sess.run(bs)
array(['test line2'], dtype=object)
>>> sess.run(b)
5
>>> sess.run(bs)
array(['test line3'], dtype=object)
>>> sess.run(bs)
array(['test line4'], dtype=object)
>>> sess.run(bs)
array(['test line5'], dtype=object)
>>> sess.run(bs)
array(['test line6'], dtype=object)
>>> sess.run(bs)
array(['test2 line1'], dtype=object)
>>> sess.run(bs)
array(['test2 line2'], dtype=object)
>>> sess.run(bs)
array(['test2 line3'], dtype=object)
>>> sess.run(bs)
array(['test2 line4'], dtype=object)
>>> sess.run(bs)
array(['test2 line5'], dtype=object)
>>> sess.run(bs)
array(['test2 line6'], dtype=object)
>>> sess.run(bs)
array(['test2 line1'], dtype=object)
>>> sess.run(bs)
array(['test2 line2'], dtype=object)
>>> sess.run(bs)
array(['test2 line3'], dtype=object)
>>> sess.run(bs)
array(['test2 line4'], dtype=object)
>>> sess.run(bs)
array(['test2 line5'], dtype=object)
>>> sess.run(bs)
array(['test2 line6'], dtype=object)
>>> sess.run(bs)
array(['test line1'], dtype=object)

我们发现,当batch_size设置成为1的时候,bs的输出是按照文件行数进行逐步打印的,原因是,我们选择的是单个Reader进行操作的,这个Reader先将test.txt文件读取,然后逐行读取并将读取的文本送到example queue(如上图)中,因为这里batch设置的是1,而且用到的是tf.train.batch()方法,中间没有shuffle,所以自然而然是按照顺序输出的,之后Reader再读取test2.txt。但是这里有一个疑惑,为什么reader.num_records_produced的第一个输出不是从1开始的,这点不太清楚。 另外,打印出filename_queue的size:

>>> sess.run(filename_queue.size())
32

发现filename_queue的size有32个之多!这点也不明白。。。

我们可以更改实验条件,将batch_size设置成2,会发现也是顺序的输出,而且每次输出为2行文本(和batch_size一样)

我们继续更改实验条件,将tf.train.batch方法换成tf.train.shuffle_batch方法,文本数据不变:

import tensorflow as tf
filenames=['test.txt','test2.txt']
filename_queue=tf.train.string_input_producer(filenames)
reader=tf.TextLineReader()
init=tf.initialize_all_variables()
key,value=reader.read(filename_queue)
bs=tf.train.shuffle_batch([value],batch_size=1,num_threads=1,capacity=4,min_after_dequeue=2)
sess=tf.Session()           
tf.train.start_queue_runners(sess=sess)
b=reader.num_records_produced()

继续刚才的执行:

>>> sess.run(bs)
array(['test2 line2'], dtype=object)
>>> sess.run(bs)
array(['test2 line5'], dtype=object)
>>> sess.run(bs)
array(['test2 line6'], dtype=object)
>>> sess.run(bs)
array(['test2 line4'], dtype=object)
>>> sess.run(bs)
array(['test2 line3'], dtype=object)
>>> sess.run(bs)
array(['test line1'], dtype=object)
>>> sess.run(bs)
array(['test line2'], dtype=object)
>>> sess.run(bs)
array(['test2 line1'], dtype=object)
>>> sess.run(bs)
array(['test line4'], dtype=object)
>>> sess.run(bs)
array(['test line5'], dtype=object)
>>> sess.run(bs)
array(['test2 line1'], dtype=object)
>>> sess.run(bs)
array(['test line3'], dtype=object)

我们发现的是,使用了shuffle操作之后,明显的bs的输出变得不一样了,变得没有规则,然后我们看filename_queue的size:

>>> sess.run(filename_queue.size())
32

发现也是32,由此估计是tensorflow会根据文件大小默认filename_queue的长度。 注意这里面的capacity=4,min_after_dequeue=2这些个命令,capacity指的是example queue的最大长度, 而min_after_dequeue是指在出队列之后,example queue最少要保留的元素个数,为什么需要这个,其实是为了混合的更显著。也正是有这两个元素,让shuffle变得可能。

到这里基本上大概的思路能明白,但是上面的实验都是对于单个的Reader,和上一节的图不太一致,根据官网教程,为了使用多个Reader,我们可以这样:

import tensorflow as tf
filenames=['test.txt','test2.txt']
filename_queue=tf.train.string_input_producer(filenames)
reader=tf.TextLineReader()
init=tf.initialize_all_variables()
key_list,value_list=[reader.read(filename_queue) for _ in range(2)]
bs2=tf.train.shuffle_batch_join([value_list],batch_size=1,capacity=4,min_after_dequeue=2)
sess=tf.Session()       
sess.run(init)    
tf.train.start_queue_runners(sess=sess)

运行的结果如下:

>>> sess.run(bs2)
[array(['test2.txt:2'], dtype=object), array(['test2 line2'], dtype=object)]
>>> sess.run(bs2)
[array(['test2.txt:5'], dtype=object), array(['test2 line5'], dtype=object)]
>>> sess.run(bs2)
[array(['test2.txt:6'], dtype=object), array(['test2 line6'], dtype=object)]
>>> sess.run(bs2)
[array(['test2.txt:4'], dtype=object), array(['test2 line4'], dtype=object)]
>>> sess.run(bs2)
[array(['test2.txt:3'], dtype=object), array(['test2 line3'], dtype=object)]
>>> sess.run(bs2)
[array(['test2.txt:1'], dtype=object), array(['test2 line1'], dtype=object)]
>>> sess.run(bs2)
[array(['test.txt:4'], dtype=object), array(['test line4'], dtype=object)]
>>> sess.run(bs2)
[array(['test.txt:3'], dtype=object), array(['test line3'], dtype=object)]
>>> sess.run(bs2)
[array(['test.txt:2'], dtype=object), array(['test line2'], dtype=object)]

以上这篇对tensorflow中cifar-10文档的Read操作详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现多线程的方式及多条命令并发执行
Jun 07 Python
Python实现统计文本文件字数的方法
May 05 Python
将TensorFlow的模型网络导出为单个文件的方法
Apr 23 Python
用Python3创建httpServer的简单方法
Jun 04 Python
Python中的引用知识点总结
May 20 Python
简单介绍python封装的基本知识
Aug 10 Python
python使用梯度下降和牛顿法寻找Rosenbrock函数最小值实例
Apr 02 Python
Python Json数据文件操作原理解析
May 09 Python
浅谈Python爬虫原理与数据抓取
Jul 21 Python
Python绘图之二维图与三维图详解
Aug 04 Python
上手简单,功能强大的Python爬虫框架——feapder
Apr 27 Python
Python中Cookies导出某站用户数据的方法
May 17 Python
基于Tensorflow:CPU性能分析
Feb 10 #Python
python sorted函数原理解析及练习
Feb 10 #Python
python pprint模块中print()和pprint()两者的区别
Feb 10 #Python
python yield和Generator函数用法详解
Feb 10 #Python
Tensorflow 卷积的梯度反向传播过程
Feb 10 #Python
tensorflow 实现自定义梯度反向传播代码
Feb 10 #Python
用Python做一个久坐提醒小助手的示例代码
Feb 10 #Python
You might like
php REMOTE_ADDR之获取访客IP的代码
2008/04/22 PHP
新浪微博API开发简介之用户授权(PHP基础篇)
2011/09/25 PHP
php 批量替换html标签的实例代码
2013/11/26 PHP
Laravel实现用户注册和登录
2015/01/23 PHP
PHP使用PDO操作sqlite数据库应用案例
2019/03/07 PHP
PHP 裁剪图片
2021/03/09 PHP
仿服务器端脚本方式的JS模板实现方法
2007/04/27 Javascript
获取dom元素那些讨厌的位置封装代码
2010/06/23 Javascript
JavaScript调用浏览器打印功能实例分析
2015/07/17 Javascript
Bootstrap每天必学之日期控制
2016/03/07 Javascript
微信小程序 122100版本更新问题解决方案
2016/12/22 Javascript
微信小程序实现带参数的分享功能(两种方法)
2019/05/17 Javascript
python实现查找excel里某一列重复数据并且剔除后打印的方法
2015/05/26 Python
Python如何抓取天猫商品详细信息及交易记录
2018/02/23 Python
详谈python3中用for循环删除列表中元素的坑
2018/04/19 Python
python寻找list中最大值、最小值并返回其所在位置的方法
2018/06/27 Python
Python3 关于pycharm自动导入包快捷设置的方法
2019/01/16 Python
Python 虚拟空间的使用代码详解
2019/06/10 Python
基于python解线性矩阵方程(numpy中的matrix类)
2019/10/21 Python
Python递归函数特点及原理解析
2020/03/04 Python
tensorflow pb to tflite 精度下降详解
2020/05/25 Python
CSS3实现银灰色动画效果的导航菜单代码
2015/09/01 HTML / CSS
css3类选择器之结合元素选择器和多类选择器用法
2017/03/09 HTML / CSS
Servlet面试题库
2015/07/18 面试题
硕士研究生自我鉴定
2013/11/08 职场文书
刚毕业大学生自荐信范文
2014/02/20 职场文书
公司承诺书格式
2014/05/21 职场文书
学校总务处领导班子民主生活会对照检查材料思想汇报
2014/09/27 职场文书
党的群众路线调研报告
2014/11/03 职场文书
工作简历自我评价
2015/03/11 职场文书
首都博物馆观后感
2015/06/05 职场文书
田径运动会广播稿
2015/08/19 职场文书
2019财务毕业实习报告
2019/06/27 职场文书
Redis Cluster 字段模糊匹配及删除
2021/05/27 Redis
OpenCV中resize函数插值算法的实现过程(五种)
2021/06/05 Python
js前端图片加载异常兜底方案
2022/06/21 Javascript