tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用wxPython开发的一个简易笔记本程序实例
Feb 08 Python
基于python的七种经典排序算法(推荐)
Dec 08 Python
Python 含参构造函数实例详解
May 25 Python
Python 列表理解及使用方法
Oct 27 Python
Pycharm设置去除显示的波浪线方法
Oct 28 Python
python用列表生成式写嵌套循环的方法
Nov 08 Python
对python中的控制条件、循环和跳出详解
Jun 24 Python
python3使用腾讯企业邮箱发送邮件的实例
Jun 28 Python
python 遗传算法求函数极值的实现代码
Feb 11 Python
python实现PDF中表格转化为Excel的方法
Jun 16 Python
Python手动或自动协程操作方法解析
Jun 22 Python
Django Form常用功能及代码示例
Oct 13 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
全国FM电台频率大全 - 14 江西省
2020/03/11 无线电
wordpress之wp-settings.php
2007/08/17 PHP
php中判断字符串是否全是中文或含有中文的实现代码
2011/09/16 PHP
PHP三元运算符的结合性介绍
2012/01/10 PHP
在yii中新增一个用户验证的方法详解
2013/06/20 PHP
CI框架常用函数封装实例
2016/11/21 PHP
js身份证验证超强脚本
2008/10/26 Javascript
Javascript 遍历对象中的子对象
2009/07/03 Javascript
javascript题目,重写函数让其无限相加
2012/02/15 Javascript
javascript时区函数介绍
2012/09/14 Javascript
基于JavaScript 类的使用详解
2013/05/07 Javascript
node.js中的fs.existsSync方法使用说明
2014/12/17 Javascript
js实现使用鼠标拖拽切换图片的方法
2015/05/04 Javascript
AngularJS整合Springmvc、Spring、Mybatis搭建开发环境
2016/02/25 Javascript
JavaScript中错误正确处理方式小结你用对了吗
2017/10/10 Javascript
详解Vue SPA项目优化小记
2018/07/03 Javascript
jQuery 操作 HTML 元素和属性的方法
2018/11/12 jQuery
详解mpvue开发微信小程序基础知识
2019/09/23 Javascript
ant design实现圈选功能
2019/12/17 Javascript
vue计算属性+vue中class与style绑定(推荐)
2020/03/30 Javascript
vue中实现弹出层动画效果的示例代码
2020/09/25 Javascript
[46:32]Fnatic vs OG 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
python类继承与子类实例初始化用法分析
2015/04/17 Python
全面了解python字符串和字典
2016/07/07 Python
python查看模块安装位置的方法
2018/10/16 Python
Django Rest framework权限的详细用法
2019/07/25 Python
keras 回调函数Callbacks 断点ModelCheckpoint教程
2020/06/18 Python
一款纯css3实现的tab选项卡的实列教程
2014/12/11 HTML / CSS
html5 分层屏幕适配的方法
2018/03/16 HTML / CSS
一套.net面试题及答案
2016/11/02 面试题
大四学生思想汇报
2014/01/13 职场文书
土地租赁意向书
2014/07/30 职场文书
户籍证明格式
2014/09/15 职场文书
用golang如何替换某个文件中的字符串
2021/04/25 Golang
Django 实现jwt认证的示例
2021/04/30 Python
go语言使用Casbin实现角色的权限控制
2021/06/26 Golang