Tensorflow分批量读取数据教程


Posted in Python onFebruary 07, 2020

之前的博客里使用tf读取数据都是每次fetch一条记录,实际上大部分时候需要fetch到一个batch的小批量数据,在tf中这一操作的明显变化就是tensor的rank发生了变化,我目前使用的人脸数据集是灰度图像,因此大小是92*112的,所以最开始fetch拿到的图像数据集经过reshape之后就是一个rank为2的tensor,大小是92*112的(如果考虑通道,也可以reshape为rank为3的,即92*112*1)。如果加入batch,比如batch大小为5,那么拿到的tensor的rank就变成了3,大小为5*92*112。

下面规则化的写一下读取数据的一般流程,按照官网的实例,一般把读取数据拆分成两个大部分,一个是函数专门负责读取数据和解码数据,一个函数则负责生产batch。

import tensorflow as tf

def read_data(fileNameQue):

  reader = tf.TFRecordReader()
  key, value = reader.read(fileNameQue)
  features = tf.parse_single_example(value, features={'label': tf.FixedLenFeature([], tf.int64),
                            'img': tf.FixedLenFeature([], tf.string),})
  img = tf.decode_raw(features["img"], tf.uint8)
  img = tf.reshape(img, [92,112]) # 恢复图像原始大小
  label = tf.cast(features["label"], tf.int32)

  return img, label

def batch_input(filename, batchSize):

  fileNameQue = tf.train.string_input_producer([filename], shuffle=True)
  img, label = read_data(fileNameQue) # fetch图像和label
  min_after_dequeue = 1000
  capacity = min_after_dequeue+3*batchSize
  # 预取图像和label并随机打乱,组成batch,此时tensor rank发生了变化,多了一个batch大小的维度
  exampleBatch,labelBatch = tf.train.shuffle_batch([img, label],batch_size=batchSize, capacity=capacity,
                           min_after_dequeue=min_after_dequeue)
  return exampleBatch,labelBatch

if __name__ == "__main__":

  init = tf.initialize_all_variables()
  exampleBatch, labelBatch = batch_input("./data/faceTF.tfrecords", batchSize=10)

  with tf.Session() as sess:

    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(100):
      example, label = sess.run([exampleBatch, labelBatch])
      print(example.shape)

    coord.request_stop()
    coord.join(threads)

读取数据和解码数据与之前基本相同,针对不同格式数据集使用不同阅读器和解码器即可,后面是产生batch,核心是tf.train.shuffle_batch这个函数,它相当于一个蓄水池的功能,第一个参数代表蓄水池的入水口,也就是逐个读取到的记录,batch_size自然就是batch的大小了,capacity是蓄水池的容量,表示能容纳多少个样本,min_after_dequeue是指出队操作后还可以供随机采样出批量数据的样本池大小,显然,capacity要大于min_after_dequeue,官网推荐:min_after_dequeue + (num_threads + a small safety margin) * batch_size,还有一个参数就是num_threads,表示所用线程数目。

min_after_dequeue这个值越大,随机采样的效果越好,但是消耗的内存也越大。

以上这篇Tensorflow分批量读取数据教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现CET查分的方法
Mar 10 Python
从局部变量和全局变量开始全面解析Python中变量的作用域
Jun 16 Python
利用Python批量压缩png方法实例(支持过滤个别文件与文件夹)
Jul 30 Python
Python数据结构与算法之图的基本实现及迭代器实例详解
Dec 12 Python
如何使用Python的Requests包实现模拟登陆
Apr 27 Python
使用matplotlib画散点图的方法
May 25 Python
详解python爬虫系列之初识爬虫
Apr 06 Python
Django中自定义admin Xadmin的实现代码
Aug 09 Python
Python实现打印实心和空心菱形
Nov 23 Python
关于numpy.where()函数 返回值的解释
Dec 06 Python
Python绘制K线图之可视化神器pyecharts的使用
Mar 02 Python
bat批处理之字符串操作的实现
Mar 16 Python
python统计字符的个数代码实例
Feb 07 #Python
Python使用PyQt5/PySide2编写一个极简的音乐播放器功能
Feb 07 #Python
Tensorflow tf.dynamic_partition矩阵拆分示例(Python3)
Feb 07 #Python
Python reshape的用法及多个二维数组合并为三维数组的实例
Feb 07 #Python
tensorflow 利用expand_dims和squeeze扩展和压缩tensor维度方式
Feb 07 #Python
Tensorflow进行多维矩阵的拆分与拼接实例
Feb 07 #Python
Tensorflow训练模型越来越慢的2种解决方案
Feb 07 #Python
You might like
PHP读取文件并可支持远程文件的代码分享
2012/10/03 PHP
hadoop常见错误以及处理方法详解
2013/06/19 PHP
Laravel5中contracts详解
2015/03/02 PHP
php注册和登录界面的实现案例(推荐)
2016/10/24 PHP
Ubuntu中支持PHP5与PHP7双版本的简单实现
2018/08/19 PHP
Laravel 修改默认日志文件名称和位置的例子
2019/10/17 PHP
laravel5.5安装jwt-auth 生成token令牌的示例
2019/10/24 PHP
javascript基础第一章 JavaScript与用户端
2010/07/22 Javascript
document.documentElement和document.body区别介绍
2013/09/16 Javascript
js网页右下角提示框实例
2014/10/14 Javascript
jQuery Validate初步体验(二)
2015/12/12 Javascript
jquery mobile开发常见问题分析
2016/01/21 Javascript
理解javascript中的with关键字
2016/02/15 Javascript
javascript特殊日历控件分享
2016/03/07 Javascript
javascript截图 jQuery插件imgAreaSelect使用详解
2016/05/04 Javascript
JS检测window.open打开的窗口是否关闭
2017/06/25 Javascript
浅谈react 同构之样式直出
2017/11/07 Javascript
使用webpack打包koa2 框架app
2018/02/02 Javascript
vue实现pdf导出解决生成canvas模糊等问题(推荐)
2018/10/18 Javascript
python根据路径导入模块的方法
2014/09/30 Python
Python网络爬虫之爬取微博热搜
2019/04/18 Python
Python matplotlib画图与中文设置操作实例分析
2019/04/23 Python
详解DeBug Python神级工具PySnooper
2019/07/03 Python
解决os.path.isdir() 判断文件夹却返回false的问题
2019/11/29 Python
Mytheresa英国官网:拥有160多个奢侈品品牌
2016/10/09 全球购物
美国性感女装网站:bebe
2017/03/04 全球购物
美国山地自行车、露营、户外装备和服装购物网站:Aventuron
2018/05/05 全球购物
PHP如何自定义函数
2016/09/16 面试题
Shell编程面试题
2012/05/30 面试题
挂牌仪式主持词
2014/03/20 职场文书
难忘的一天教学反思
2014/04/30 职场文书
班主任工作经验交流材料
2014/05/13 职场文书
派出所班子党的群众路线对照检查材料思想汇报
2014/10/01 职场文书
个人原因辞职信模板
2015/05/13 职场文书
基层党建工作简报
2015/07/21 职场文书
JS 4个超级实用的小技巧 提升开发效率
2021/10/05 Javascript