tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现抓取页面上链接的简单爬虫分享
Jan 21 Python
Python实现将文本生成二维码的方法示例
Jul 18 Python
python实现简单聊天应用 python群聊和点对点均实现
Sep 14 Python
Python实现的摇骰子猜大小功能小游戏示例
Dec 18 Python
解决安装tensorflow遇到无法卸载numpy 1.8.0rc1的问题
Jun 13 Python
浅谈python str.format与制表符\t关于中文对齐的细节问题
Jan 14 Python
Python正则表达式和re库知识点总结
Feb 11 Python
django模板加载静态文件的方法步骤
Mar 01 Python
Python 实现将numpy中的nan和inf,nan替换成对应的均值
Jun 08 Python
Python基于pyjnius库实现访问java类
Jul 31 Python
python 获取剪切板内容的两种方法
Nov 28 Python
Python+tkinter实现高清图片保存
Mar 13 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
smarty巧妙处理iframe中内容页的代码
2012/03/07 PHP
CodeIgniter中使用cookie的三种方式详解
2014/07/18 PHP
Laravel中扩展Memcached缓存驱动实现使用阿里云OCS缓存
2015/02/10 PHP
IE innerHTML,outerHTML所引起的问题
2009/06/04 Javascript
用js写了一个类似php的print_r输出换行功能
2013/02/18 Javascript
浅析JavaScript中的CSS属性及命名规范
2013/11/28 Javascript
jQuery表单域选择器用法分析
2015/02/10 Javascript
微信小程序page的生命周期和音频播放及监听实例详解
2017/04/07 Javascript
JS/jquery实现一个网页内同时调用多个倒计时的方法
2017/04/27 jQuery
vue + webpack如何绕过QQ音乐接口对host的验证详解
2018/07/01 Javascript
JavaScript计算正方形面积
2019/11/26 Javascript
[02:46]完美世界DOTA2联赛PWL DAY4集锦
2020/11/03 DOTA
[35:27]完美世界DOTA2联赛循环赛 GXR vs FTD BO2第二场 10.29
2020/10/29 DOTA
Python文件操作类操作实例详解
2014/07/11 Python
Python中的exec、eval使用实例
2014/09/23 Python
Python爬虫框架Scrapy实战之批量抓取招聘信息
2015/08/07 Python
python 文件操作删除某行的实例
2017/09/04 Python
python递归函数绘制分形树的方法
2018/06/22 Python
python自动发邮件总结及实例说明【推荐】
2019/05/31 Python
Python3基本输入与输出操作实例分析
2020/02/14 Python
python词云库wordCloud使用方法详解(解决中文乱码)
2020/02/17 Python
2019年c语言经典面试题目
2016/08/17 面试题
SQL数据库笔试题
2016/03/08 面试题
建筑专业毕业生推荐信
2013/11/21 职场文书
护士自我鉴定怎么写
2014/02/07 职场文书
公司年会主持词
2014/03/22 职场文书
国贸专业毕业求职信
2014/06/11 职场文书
中秋节活动总结
2014/08/29 职场文书
四风问题班子对照检查材料
2014/09/27 职场文书
先进人物事迹材料
2014/12/29 职场文书
爱鸟护鸟的宣传语
2015/07/13 职场文书
2015年学校少先队工作总结
2015/07/20 职场文书
2016年党员承诺书范文
2016/03/24 职场文书
解决redis sentinel 频繁主备切换的问题
2021/04/12 Redis
golang switch语句的灵活写法介绍
2021/05/06 Golang
nginx lua 操作 mysql
2022/05/15 Servers