tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python模拟登陆Tom邮箱示例分享
Jan 13 Python
Python最长公共子串算法实例
Mar 07 Python
初步介绍Python中的pydoc模块和distutils模块
Apr 13 Python
浅谈python可视化包Bokeh
Feb 07 Python
django Serializer序列化使用方法详解
Oct 16 Python
树莓派动作捕捉抓拍存储图像脚本
Jun 22 Python
python用线性回归预测股票价格的实现代码
Sep 04 Python
Python爬虫实现使用beautifulSoup4爬取名言网功能案例
Sep 15 Python
PyCharm License Activation激活码失效问题的解决方法(图文详解)
Mar 12 Python
python函数map()和partial()的知识点总结
May 26 Python
Python 获取异常(Exception)信息的几种方法
Dec 29 Python
请求模块urllib之PYTHON爬虫的基本使用
Apr 08 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
PHP中实现图片的锐化
2006/10/09 PHP
屏蔽浏览器缓存另类方法
2006/10/09 PHP
php中$_SERVER[PHP_SELF] 和 $_SERVER[SCRIPT_NAME]之间的区别
2009/09/05 PHP
PHP数组及条件,循环语句学习
2012/11/11 PHP
php弹出对话框实现重定向代码
2014/01/23 PHP
php简单判断两个字符串是否相等的方法
2015/07/13 PHP
PHP正则表达式入门教程(推荐)
2016/05/18 PHP
PHP引用返回用法示例
2016/05/28 PHP
yii 框架实现按天,月,年,自定义时间段统计数据的方法分析
2020/04/04 PHP
js wmp操作代码小结(音乐连播功能)
2008/11/08 Javascript
基于jquery1.4.2的仿flash超炫焦点图播放效果
2010/04/20 Javascript
JavaScript实现在标题栏上显示当前日期的方法
2015/03/19 Javascript
javascript每日必学之条件分支
2016/02/17 Javascript
jQuery数组处理函数整理
2016/08/03 Javascript
第一次动手实现bootstrap table分页效果
2016/09/22 Javascript
js 动态生成html 触发事件传参字符转义的实例
2017/02/14 Javascript
JavaScript组件开发之输入框加候选框
2017/03/10 Javascript
JavaScript中的普通函数和箭头函数的区别和用法详解
2017/03/21 Javascript
Angular 2父子组件数据传递之@Input和@Output详解 (上)
2017/07/05 Javascript
easyui datagrid 表格中操作栏 按钮图标不显示的解决方法
2017/07/27 Javascript
Node.js模拟发起http请求从异步转同步的5种用法
2018/09/26 Javascript
详解关于html,css,js三者的加载顺序问题
2019/04/10 Javascript
vue项目中锚点定位替代方式
2019/11/13 Javascript
vue使用axios实现excel文件下载的功能
2020/07/16 Javascript
vue实现点击按钮“查看详情”弹窗展示详情列表操作
2020/09/09 Javascript
浅析Python中else语句块的使用技巧
2016/06/16 Python
cmd运行python文件时对结果进行保存的方法
2018/05/16 Python
python数据处理 根据颜色对图片进行分类的方法
2018/12/08 Python
基于python2.7实现图形密码生成器的实例代码
2019/11/05 Python
python实现矩阵和array数组之间的转换
2019/11/29 Python
企业军训感想
2014/02/07 职场文书
2014年党课学习心得体会
2014/07/08 职场文书
2014法院干警廉洁警示教育思想汇报
2014/09/13 职场文书
mysql的MVCC多版本并发控制的实现
2021/04/14 MySQL
go原生库的中bytes.Buffer用法
2021/04/25 Golang
pytorch交叉熵损失函数的weight参数的使用
2021/05/24 Python