tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中dict与set的使用
Aug 10 Python
详解python的webrtc库实现语音端点检测
May 31 Python
简单了解python模块概念
Jan 11 Python
python利用socketserver实现并发套接字功能
Jan 26 Python
pandas数据预处理之dataframe的groupby操作方法
Apr 13 Python
详解Python发送email的三种方式
Oct 18 Python
pandas读取CSV文件时查看修改各列的数据类型格式
Jul 07 Python
使用python打印十行杨辉三角过程详解
Jul 10 Python
python+numpy实现的基本矩阵操作示例
Jul 19 Python
python lambda表达式(匿名函数)写法解析
Sep 16 Python
python基于gevent实现并发下载器代码实例
Nov 01 Python
Python切片列表字符串如何实现切换
Aug 06 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
php抓取https的内容的代码
2010/04/06 PHP
smarty模板引擎之分配数据类型
2015/03/30 PHP
PHP pthreads v3下同步处理synchronized用法示例
2020/02/21 PHP
use jscript List Installed Software
2007/06/11 Javascript
jQuery 学习第五课 Ajax 使用说明
2010/05/17 Javascript
jsp+javascript打造级连菜单的实例代码
2013/06/14 Javascript
百度移动版的url编码解码示例
2014/04/29 Javascript
JavaScript实现简单图片滚动附源码下载
2014/06/17 Javascript
百度地图自定义控件分享
2015/03/04 Javascript
jQuery实现分章节锚点“回到顶部”动画特效代码
2015/10/23 Javascript
解决Window10系统下Node安装报错的问题分析
2016/12/13 Javascript
工作中常用的js、jquery自定义扩展函数代码片段汇总
2016/12/22 Javascript
js常用DOM方法详解
2017/02/04 Javascript
JS中如何实现Laravel的route函数详解
2017/02/12 Javascript
详解mpvue中使用vant时需要注意的onChange事件的坑
2019/05/16 Javascript
Vue基础学习之项目整合及优化
2019/06/02 Javascript
微信小程序获取公众号文章列表及显示文章的示例代码
2020/03/10 Javascript
vue 重塑数组之修改数组指定index的值操作
2020/08/09 Javascript
JavaScript事件委托实现原理及优点进行
2020/08/29 Javascript
Python的Tornado框架的异步任务与AsyncHTTPClient
2016/06/27 Python
Python入门之三角函数tan()函数实例详解
2017/11/08 Python
用python处理MS Word的实例讲解
2018/05/08 Python
Python 字节流,字符串,十六进制相互转换实例(binascii,bytes)
2020/05/11 Python
python中openpyxl和xlsxwriter对Excel的操作方法
2021/03/01 Python
什么是ARP(Address Resolution Protocol)地址解析协议
2013/10/31 面试题
2013年入党人员的自我鉴定
2013/10/25 职场文书
初中生学习生活的自我评价
2013/11/20 职场文书
企业安全生产责任书
2014/04/14 职场文书
演讲稿格式范文
2014/05/19 职场文书
班主任2015新年寄语
2014/12/08 职场文书
办公室个人总结
2015/02/28 职场文书
学校清洁工岗位职责
2015/04/15 职场文书
义诊活动通知
2015/04/24 职场文书
企业党员岗位承诺书
2015/04/27 职场文书
Python爬虫之自动爬取某车之家各车销售数据
2021/06/02 Python
VMware虚拟机安装 Windows Server 2022的详细图文教程
2022/09/23 Servers