tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python关键字and和or用法实例
May 28 Python
python DataFrame 修改列的顺序实例
Apr 10 Python
Python基于lxml模块解析html获取页面内所有叶子节点xpath路径功能示例
May 16 Python
Django如何自定义分页
Sep 25 Python
python-tkinter之按钮的使用,开关方法
Jun 11 Python
python实现列表的排序方法分享
Jul 01 Python
django如何通过类视图使用装饰器
Jul 24 Python
Python使用字典实现的简单记事本功能示例
Aug 15 Python
python3实现elasticsearch批量更新数据
Dec 03 Python
使用python执行shell脚本 并动态传参 及subprocess的使用详解
Mar 06 Python
基于Python组装jmx并调用JMeter实现压力测试
Nov 03 Python
Python爬虫之Selenium实现关闭浏览器
Dec 04 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
搜索和替换文件或目录的一个好类--很实用
2006/10/09 PHP
PHP程序员编程注意事项
2008/04/10 PHP
php file_put_contents()功能函数(集成了fopen、fwrite、fclose)
2011/05/24 PHP
简单的php中文转拼音的实现代码
2014/02/11 PHP
php file_get_contents取文件中数组元素的方法
2017/04/01 PHP
PHP支付宝当面付2.0代码
2018/12/21 PHP
TFDN图片播放器 不错自动播放
2006/10/03 Javascript
Javascript生成json的函数代码(可以用php的json_decode解码)
2012/06/11 Javascript
javascript中全局对象的parseInt()方法使用介绍
2013/12/19 Javascript
JavaScript中的迭代器和生成器详解
2014/10/29 Javascript
JavaScript实现控制打开文件另存为对话框的方法
2015/04/17 Javascript
js点击列表文字对应该行显示背景颜色的实现代码
2015/08/05 Javascript
jQuery height()、innerHeight()、outerHeight()函数的区别详解
2016/05/23 Javascript
EasyUI在Panel上动态添加LinkButton按钮
2017/08/11 Javascript
JavaScript中Object值合并方法详解
2017/12/22 Javascript
Vue使用json-server进行后端数据模拟功能
2018/04/17 Javascript
Hexo已经看腻了,来手把手教你使用VuePress搭建个人博客
2018/04/26 Javascript
通过jquery的ajax请求本地的json文件方法
2018/08/08 jQuery
Angular2实现的秒表及改良版示例
2019/05/10 Javascript
VUEX 数据持久化,刷新后重新获取的例子
2019/11/12 Javascript
Vue中多元素过渡特效的解决方案
2020/02/05 Javascript
[02:16]完美世界DOTA2联赛PWL S3 集锦第三期
2020/12/21 DOTA
python发送邮件接收邮件示例分享
2014/01/21 Python
Python基于pillow判断图片完整性的方法
2016/09/18 Python
Python实现复杂对象转JSON的方法示例
2017/06/22 Python
使用Django Form解决表单数据无法动态刷新的两种方法
2017/07/14 Python
Python基于回溯法子集树模板实现图的遍历功能示例
2017/09/05 Python
django rest framework 数据的查找、过滤、排序的示例
2018/06/25 Python
Python 脚本拉取 Docker 镜像问题
2019/11/10 Python
python pygame实现滚动横版射击游戏城市之战
2019/11/25 Python
python使用html2text库实现从HTML转markdown的方法详解
2020/02/21 Python
Crocs美国官方网站:卡骆驰洞洞鞋
2017/08/04 全球购物
澳大利亚相机之家:Camera House
2017/11/30 全球购物
社会实践单位意见
2015/06/05 职场文书
关于感恩的素材句子(38句)
2019/11/11 职场文书
python tqdm用法及实例详解
2021/06/16 Python