tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python ValueError: invalid literal for int() with base 10 实用解决方法
Jun 21 Python
python对象及面向对象技术详解
Jul 19 Python
Python序列操作之进阶篇
Dec 08 Python
Python微信库:itchat的用法详解
Aug 14 Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 Python
Python实现扣除个人税后的工资计算器示例
Mar 26 Python
如何用Python来搭建一个简单的推荐系统
Aug 07 Python
python判断自身是否正在运行的方法
Aug 08 Python
python图形绘制奥运五环实例讲解
Sep 14 Python
Python 网络编程之UDP发送接收数据功能示例【基于socket套接字】
Oct 11 Python
使用Python实现分别输出每个数组
Dec 06 Python
python基于turtle绘制几何图形
Jun 15 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
PHP原理之异常机制深入分析
2010/08/08 PHP
PHP 八种基本的数据类型小结
2011/06/01 PHP
解析php中const与define的应用区别
2013/06/18 PHP
php实现递归的三种基本方式
2020/07/04 PHP
YII Framework框架教程之使用YIIC快速创建YII应用详解
2016/03/15 PHP
firefox和IE系列的相关区别整理 以备后用
2009/12/28 Javascript
js实现页面转发功能示例代码
2013/08/05 Javascript
JavaScript调用客户端的可执行文件(示例代码)
2013/11/28 Javascript
js实现仿爱微网两级导航菜单效果代码
2015/08/31 Javascript
jQuery实现的登录浮动框效果代码
2015/09/26 Javascript
原生js实现autocomplete插件
2016/04/14 Javascript
js canvas实现擦除动画
2016/07/16 Javascript
浅谈js函数中的实例对象、类对象、局部变量(局部函数)
2016/11/20 Javascript
Ionic+AngularJS实现登录和注册带验证功能
2017/02/09 Javascript
jQuery+vue.js实现的九宫格拼图游戏完整实例【附源码下载】
2017/09/12 jQuery
基于vue的短信验证码倒计时demo
2017/09/13 Javascript
javascript中this的用法实践分析
2019/07/29 Javascript
[04:03]DOTA2肉山黑名单梦之声 风暴之灵中文配音鉴赏
2013/07/03 DOTA
[05:09]2016国际邀请赛中国区预选赛淘汰赛首日精彩回顾
2016/06/29 DOTA
Python错误: SyntaxError: Non-ASCII character解决办法
2017/06/08 Python
Python2.7基于笛卡尔积算法实现N个数组的排列组合运算示例
2017/11/23 Python
python如何实现int函数的方法示例
2018/02/19 Python
Python中py文件引用另一个py文件变量的方法
2018/04/29 Python
解决seaborn在pycharm中绘图不出图的问题
2018/05/24 Python
python实现zabbix发送短信脚本
2018/09/17 Python
python中logging模块的一些简单用法的使用
2019/02/22 Python
python 判断字符串中是否含有汉字或非汉字的实例
2019/07/15 Python
python实现对变位词的判断方法
2020/04/05 Python
win7上tensorflow2.2.0安装成功 引用DLL load failed时找不到指定模块 tensorflow has no attribute xxx 解决方法
2020/05/20 Python
python线性插值解析
2020/07/05 Python
Interrail法国:乘火车探索欧洲,最受欢迎的欧洲铁路通票
2019/08/27 全球购物
2015大学迎新标语
2015/07/16 职场文书
公司劳动纪律管理制度
2015/08/04 职场文书
新员工入职感想
2015/08/07 职场文书
python文本处理的方案(结巴分词并去除符号)
2021/05/26 Python
flex弹性布局详解
2022/03/20 HTML / CSS