tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中__slots__用法实例
Jun 04 Python
使用pandas对矢量化数据进行替换处理的方法
Apr 11 Python
Python 字符串转换为整形和浮点类型的方法
Jul 17 Python
Python实现获取汉字偏旁部首的方法示例【测试可用】
Dec 18 Python
Selenium+Python 自动化操控登录界面实例(有简单验证码图片校验)
Jun 28 Python
关于Python 的简单栅格图像边界提取方法
Jul 05 Python
在python中实现同行输入/接收多个数据的示例
Jul 20 Python
ubuntu 18.04 安装opencv3.4.5的教程(图解)
Nov 04 Python
用openCV和Python 实现图片对比,并标识出不同点的方式
Dec 19 Python
基于Python的身份证验证识别和数据处理详解
Nov 14 Python
Python中使用subprocess库创建附加进程
May 11 Python
python入门学习关于for else的特殊特性讲解
Nov 20 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
php在字符串中查找另一个字符串
2008/11/19 PHP
PHP计数器的实现代码
2013/06/08 PHP
php中文验证码实现方法
2015/06/18 PHP
PHP抓取及分析网页的方法详解
2016/04/26 PHP
PHP常见数组排序方法小结
2018/08/20 PHP
Laravel 对某一列进行筛选然后求和sum()的例子
2019/10/10 PHP
用JavaScript玩转游戏物理(一)运动学模拟与粒子系统
2010/06/19 Javascript
JQuery.Ajax之错误调试帮助信息介绍
2013/07/04 Javascript
jQuery事件用法实例汇总
2014/08/29 Javascript
Angular中的Promise对象($q介绍)
2015/03/03 Javascript
Bootstrap每天必学之栅格系统(布局)
2015/11/25 Javascript
AngularJS 过滤器(自带和自建)详解
2016/09/19 Javascript
require.js 加载 vue组件 r.js 合并压缩的实例
2016/10/14 Javascript
vue导出html、word和pdf的实现代码
2018/07/31 Javascript
使用mpvue搭建一个初始小程序及项目配置方法
2018/12/03 Javascript
解决layui-table单元格设置为百分比在ie8下不能自适应的问题
2019/09/28 Javascript
jQuery轮播图功能制作方法详解
2019/12/03 jQuery
TypeScript的安装、使用、自动编译的实现
2020/04/10 Javascript
python爬虫框架scrapy实战之爬取京东商城进阶篇
2017/04/24 Python
Python线程同步的实现代码
2018/10/03 Python
Python中的函数式编程:不可变的数据结构
2018/10/08 Python
Django中数据库的数据关系:一对一,一对多,多对多
2018/10/21 Python
快速解决pyqt5窗体关闭后子线程不同时退出的问题
2019/06/19 Python
python 协程 gevent原理与用法分析
2019/11/22 Python
更新升级python和pip版本后不生效的问题解决
2020/04/17 Python
python drf各类组件的用法和作用
2021/01/12 Python
css3背景图片透明叠加属性cross-fade简介及用法实例
2013/01/08 HTML / CSS
奥兰多迪士尼门票折扣:Undercover Tourist
2018/07/09 全球购物
MAC Cosmetics官方网站:魅可专业艺术彩妆
2019/04/10 全球购物
六查六看剖析材料
2014/02/15 职场文书
精神文明单位申报材料
2014/05/02 职场文书
水利水电建筑施工应届生求职信
2014/07/04 职场文书
2014年纪委工作总结
2014/12/05 职场文书
2015年防汛工作总结
2015/05/15 职场文书
微信小程序结合ThinkPHP5授权登陆后获取手机号
2021/11/23 PHP
gtx1650怎么样 gtx1650显卡相当于什么级别
2022/04/08 数码科技