tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
itchat和matplotlib的结合使用爬取微信信息的实例
Aug 25 Python
python Socket之客户端和服务端握手详解
Sep 18 Python
Python使用文件锁实现进程间同步功能【基于fcntl模块】
Oct 16 Python
python爬虫使用cookie登录详解
Dec 27 Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 Python
用Python PIL实现几个简单的图片特效
Jan 18 Python
python实现K近邻回归,采用等权重和不等权重的方法
Jan 23 Python
PyQt5的安装配置过程,将ui文件转为py文件后显示窗口的实例
Jun 19 Python
Python类中方法getitem和getattr详解
Aug 30 Python
Python模拟登录之滑块验证码的破解(实例代码)
Nov 18 Python
python 5个实用的技巧
Sep 27 Python
OpenCV全景图像拼接的实现示例
Jun 05 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
Android ProgressBar进度条和ProgressDialog进度框的展示DEMO
2013/06/19 PHP
PHP实现 APP端微信支付功能
2018/06/22 PHP
你可能不再需要JQUERY
2021/03/09 Javascript
JS日历 推荐
2006/12/03 Javascript
零基础学JavaScript最新动画教程+iso光盘下载
2008/01/22 Javascript
jquery异步循环获取功能实现代码
2010/09/19 Javascript
jquery判断浏览器类型的代码
2012/11/05 Javascript
如何使用jQUery获取选中radio对应的值(一句代码)
2013/06/03 Javascript
改变隐藏的input中value值的方法
2014/03/19 Javascript
提高jQuery性能优化的技巧
2015/08/03 Javascript
JavaScript实现定时隐藏与显示图片的方法
2015/08/06 Javascript
AngularJS ng-blur 指令详解及简单实例
2016/07/30 Javascript
BootStrap glyphicon图标无法显示的解决方法
2016/09/06 Javascript
Vuex之理解state的用法实例
2017/04/19 Javascript
使用Node.js搭建静态资源服务详细教程
2017/08/02 Javascript
React Native时间转换格式工具类分享
2017/10/24 Javascript
在vue中通过axios异步使用echarts的方法
2018/01/13 Javascript
NodeJS 中Stream 的基本使用
2018/07/30 NodeJs
微信小程序使用template标签实现五星评分功能
2018/11/03 Javascript
JavaScript实现选项卡效果的分析及步骤
2019/04/16 Javascript
layui表格 列自动适应大小失效的解决方法
2019/09/06 Javascript
你不可不知的Vue.js列表渲染详解
2019/10/01 Javascript
Python functools模块学习总结
2015/05/09 Python
Python3中使用PyMongo的方法详解
2017/07/28 Python
python获取当前目录路径和上级路径的实例
2018/04/26 Python
Python 使用 Pillow 模块给图片添加文字水印的方法
2019/08/30 Python
JAVA SWT事件四种写法实例解析
2020/06/05 Python
推荐10个CSS3 制作的创意下拉菜单效果
2014/02/11 HTML / CSS
用CSS3写的模仿iPhone中的返回按钮
2015/04/04 HTML / CSS
英国时尚家具、家居饰品及礼品商店:Graham & Green
2016/09/15 全球购物
TCP/IP的分层模型
2013/10/27 面试题
支教自我鉴定
2014/01/18 职场文书
党员承诺书格式
2014/05/21 职场文书
医院2014国庆节活动策划方案
2014/09/21 职场文书
浅谈pytorch中stack和cat的及to_tensor的坑
2021/05/20 Python
Python+Selenium自动化环境搭建与操作基础详解
2022/03/13 Python