tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Django中进行用户注册和邮箱验证的方法
May 09 Python
Python实现曲线点抽稀算法的示例
Oct 12 Python
Python网络编程之TCP套接字简单用法示例
Apr 09 Python
Django unittest 设置跳过某些case的方法
Dec 26 Python
对python过滤器和lambda函数的用法详解
Jan 21 Python
Python Pandas 获取列匹配特定值的行的索引问题
Jul 01 Python
python中web框架的自定义创建
Sep 08 Python
django序列化serializers过程解析
Dec 14 Python
python访问hdfs的操作
Jun 06 Python
高考考python编程是真的吗
Jul 20 Python
Python threading模块condition原理及运行流程详解
Oct 05 Python
Python内置包对JSON文件数据进行编码和解码
Apr 12 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
Php部分常见问题总结
2006/10/09 PHP
防止MySQL注入或HTML表单滥用的PHP程序
2009/01/21 PHP
细谈php中SQL注入攻击与XSS攻击
2012/06/10 PHP
php异步多线程swoole用法实例
2014/11/14 PHP
php通过Chianz.com获取IP地址与地区的方法
2015/01/14 PHP
十幅图告诉你什么是PHP引用
2015/02/22 PHP
PHP基本语法实例总结
2016/09/09 PHP
Laravel 5.5 的自定义验证对象/类示例代码详解
2017/08/29 PHP
js 颜色选择器(兼容firefox)
2009/03/05 Javascript
原生js 秒表实现代码
2012/07/24 Javascript
js中parseInt函数浅谈
2013/07/31 Javascript
一款由jquery实现的整屏切换特效
2014/09/15 Javascript
js获取页面及个元素高度、宽度的代码
2016/04/26 Javascript
DOM中事件处理概览与原理的全面解析
2016/08/16 Javascript
原生js封装的一些jquery方法(详解)
2016/09/20 Javascript
移动端滑动插件Swipe教程
2016/10/16 Javascript
Bootstrap入门教程一Hello Bootstrap初识
2017/03/02 Javascript
js实现分页功能
2017/05/24 Javascript
Python求两个文本文件以行为单位的交集、并集与差集的方法
2015/06/17 Python
Python基于回溯法子集树模板解决m着色问题示例
2017/09/07 Python
Python3实战之爬虫抓取网易云音乐的热门评论
2017/10/09 Python
python如何在循环引用中管理内存
2018/03/20 Python
python三大神器之fabric使用教程
2019/06/10 Python
Python的bit_length函数来二进制的位数方法
2019/08/27 Python
自我鉴定范文
2013/11/10 职场文书
2014副局长群众路线对照检查材料思想汇报
2014/09/22 职场文书
综治工作汇报材料
2014/10/27 职场文书
大学生预备党员自我评价
2015/03/04 职场文书
保护环境建议书作文500字
2015/09/14 职场文书
实习员工转正的评语汇总,以备不时之需
2019/12/17 职场文书
sql注入教程之类型以及提交注入
2021/08/02 MySQL
vue如何使用模拟的json数据查看效果
2022/03/31 Vue.js
Ruby GDBM操作简介及数据存储原理
2022/04/19 Ruby
选购到合适的激光打印机
2022/04/21 数码科技
python使用BeautifulSoup 解析HTML
2022/04/24 Python
nginx 添加http_stub_status_module模块
2022/05/25 Servers