tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python对字典进行排序实例
Sep 25 Python
Python 内置函数memoryview(obj)的具体用法
Nov 23 Python
Python实现购物车程序
Apr 16 Python
python和shell监控linux服务器的详细代码
Jun 22 Python
Python解决两个整数相除只得到整数部分的实例
Nov 10 Python
Python中面向对象你应该知道的一下知识
Jul 10 Python
Python pandas自定义函数的使用方法示例
Nov 20 Python
python为QT程序添加图标的方法详解
Mar 09 Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 Python
手把手教你将Flask应用封装成Docker服务的实现
Aug 19 Python
django注册用邮箱发送验证码的实现
Apr 18 Python
python入门之算法学习
Apr 22 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
探讨:如何通过stats命令分析Memcached的内部状态
2013/06/14 PHP
php中 $$str 中 "$$" 的详解
2015/07/06 PHP
php通过淘宝API查询IP地址归属等信息
2015/12/25 PHP
PHP获取二叉树镜像的方法
2018/01/17 PHP
TP5框架实现的数据库备份功能示例
2020/04/05 PHP
如何确保JavaScript的执行顺序 之实战篇
2011/03/03 Javascript
jQuery如何实现点击页面获得当前点击元素的id或其他信息
2014/01/09 Javascript
红米手机抢购的js代码
2014/03/10 Javascript
node.js中的buffer.Buffer.isBuffer方法使用说明
2014/12/14 Javascript
详解参数传递四种形式
2015/07/21 Javascript
jQuery图片轮播滚动切换代码分享
2020/04/20 Javascript
url中的特殊符号有什么含义(推荐)
2016/06/17 Javascript
浅谈Nodejs应用主文件index.js
2016/08/28 NodeJs
bootstrap读书笔记之CSS组件(上)
2016/10/17 Javascript
值得分享的Bootstrap Table使用教程
2016/11/23 Javascript
原生js开发的日历插件
2017/02/04 Javascript
BootStrap注意事项小结(五)表单
2017/03/10 Javascript
vue复合组件实现注册表单功能
2017/11/06 Javascript
vue两组件间值传递 $router.push实现方法
2019/05/15 Javascript
layui的select联动实现代码
2019/09/28 Javascript
js实现盒子滚动动画效果
2020/08/09 Javascript
vue+echarts实现动态折线图的方法与注意
2020/09/01 Javascript
vue3 watch和watchEffect的使用以及有哪些区别
2021/01/26 Vue.js
[48:48]完美世界DOTA2联赛PWL S3 Magama vs GXR 第一场 12.19
2020/12/24 DOTA
使用Python求解最大公约数的实现方法
2015/08/20 Python
使用Python的Flask框架构建大型Web应用程序的结构示例
2016/06/04 Python
Python3多线程操作简单示例
2018/05/22 Python
python实现全排列代码(回溯、深度优先搜索)
2020/02/26 Python
幼儿园中班开学寄语
2014/04/03 职场文书
实习生岗位职责
2014/04/12 职场文书
2014年公务员个人工作总结
2014/11/22 职场文书
接待员岗位职责
2015/02/13 职场文书
小学六年级毕业感言
2015/07/30 职场文书
微信小程序实现聊天室功能
2021/06/14 Javascript
python开发人人对战的五子棋小游戏
2022/05/02 Python
go goth封装第三方认证库示例详解
2022/08/14 Golang