tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的类与对象之描述符详解
Mar 27 Python
简单介绍Python2.x版本中的cmp()方法的使用
May 20 Python
如何将python中的List转化成dictionary
Aug 15 Python
利用Python自动监控网站并发送邮件告警的方法
Aug 24 Python
浅谈python内置变量-reversed(seq)
Jun 21 Python
对python的unittest架构公共参数token提取方法详解
Dec 17 Python
python实现的读取网页并分词功能示例
Oct 29 Python
Django框架教程之中间件MiddleWare浅析
Dec 29 Python
Python暴力破解Mysql数据的示例
Nov 09 Python
Python高并发和多线程有什么关系
Nov 14 Python
Python类class参数self原理解析
Nov 19 Python
Python pandas求方差和标准差的方法实例
Aug 04 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
phpmyadmin的#1251问题
2006/11/25 PHP
用php实现的下载css文件中的图片的代码
2010/02/08 PHP
PHP5 的对象赋值机制介绍
2011/08/02 PHP
php 短链接算法收集与分析
2011/12/30 PHP
web站点获取用户IP的安全方法 HTTP_X_FORWARDED_FOR检验
2013/06/01 PHP
文件上传之SWFUpload插件(代码)
2015/07/30 PHP
ie 处理 gif动画 的onload 事件的一个 bug
2007/04/12 Javascript
js+CSS 图片等比缩小并垂直居中实现代码
2008/12/01 Javascript
如何确保JavaScript的执行顺序 之jQuery.html深度分析
2011/03/03 Javascript
javascript:history.go()和History.back()的区别及应用
2012/11/25 Javascript
Javascript图像处理思路及实现代码
2012/12/25 Javascript
jquery教程限制文本框只能输入数字和小数点示例分享
2014/01/13 Javascript
谈谈我对JavaScript DOM事件的理解
2015/12/18 Javascript
简单的JS时钟实例讲解
2016/01/13 Javascript
canvas实现图像布局填充功能
2017/02/06 Javascript
详解webpack 多入口配置
2017/06/16 Javascript
详解node+express+ejs+bootstrap构建项目
2017/09/27 Javascript
ES6中的Promise代码详解
2017/10/09 Javascript
NodeJS爬虫实例之糗事百科
2017/12/14 NodeJs
解决Angular.js中使用Swiper插件不能滑动的问题
2018/02/26 Javascript
详解Angular6 热加载配置方案
2018/08/18 Javascript
vue-axios同时请求多个接口 等所有接口全部加载完成再处理操作
2020/11/09 Javascript
[52:52]DOTA2上海特级锦标赛C组资格赛#1 OG VS LGD第三局
2016/02/27 DOTA
用Python登录Gmail并发送Gmail邮件的教程
2015/04/17 Python
Python读取mat文件,并保存为pickle格式的方法
2018/10/23 Python
TensorFlow tf.nn.max_pool实现池化操作方式
2020/01/04 Python
Django自定义YamlField实现过程解析
2020/11/11 Python
HTML5在微信内置浏览器下右上角菜单的调整字体导致页面显示错乱的问题
2021/01/19 HTML / CSS
出国留学自荐信
2013/10/25 职场文书
正规的求职信范文分享
2013/12/11 职场文书
闭幕式主持词
2014/04/02 职场文书
初级党校心得体会
2014/09/11 职场文书
个人自我剖析材料
2014/09/30 职场文书
财务部岗位职责
2015/02/03 职场文书
2019秋季运动会口号
2019/06/25 职场文书
Python中OpenCV实现查找轮廓的实例
2021/06/08 Python