tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中解析JSON并同时进行自定义编码处理实例
Feb 08 Python
Python获取某一天是星期几的方法示例
Jan 17 Python
Python基于正则表达式实现检查文件内容的方法【文件检索】
Aug 30 Python
Python简单实现的代理服务器端口映射功能示例
Apr 08 Python
基于Python 装饰器装饰类中的方法实例
Apr 21 Python
Python退火算法在高次方程的应用
Jul 26 Python
Django在admin后台集成TinyMCE富文本编辑器的例子
Aug 09 Python
用Python解数独的方法示例
Oct 24 Python
python网络编程之五子棋游戏
May 14 Python
Django model重写save方法及update踩坑详解
Jul 27 Python
python3从网络摄像机解析mjpeg http流的示例
Nov 13 Python
paramiko使用tail实时获取服务器的日志输出详解
Dec 06 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
什么是调频(FM)、调幅(AM)、短波(SW)、长波(LW)
2021/03/01 无线电
通俗易懂的php防注入代码
2010/04/07 PHP
利用php+mcDropdown实现文件路径可在下拉框选择
2013/08/07 PHP
学习php分页代码实例
2013/10/24 PHP
Symfony数据校验方法实例分析
2015/01/26 PHP
Yii使用migrate命令执行sql语句的方法
2016/03/15 PHP
PHP中的浅复制与深复制的实例详解
2017/10/26 PHP
PHP 代码简洁之道(小结)
2019/10/16 PHP
JavaScript 自动完成脚本整理(33个)
2009/10/20 Javascript
用jQuery中的ajax分页实现代码
2011/09/20 Javascript
jQuery中document与window以及load与ready 区别详解
2014/12/29 Javascript
Javascript核心读书有感之词法结构
2015/02/01 Javascript
javascript实现起伏的水波背景效果
2016/05/16 Javascript
jQuery解决input元素的blur事件和其他非表单元素的click事件冲突问题
2016/08/15 Javascript
微信小程序(应用号)简单实例应用及实例详解
2016/09/26 Javascript
jQuery实现的简单排序功能示例【冒泡排序】
2017/01/13 Javascript
利用策略模式与装饰模式扩展JavaScript表单验证功能
2017/02/14 Javascript
Nodejs进阶:express+session实现简易登录身份认证
2017/04/24 NodeJs
jQuery+SpringMVC中的复选框选择与传值实例
2018/01/08 jQuery
微信小程序仿朋友圈发布动态功能
2018/07/15 Javascript
vue调用本地摄像头实现拍照功能
2020/08/14 Javascript
[01:10]3.19DOTA2发布会 三代刀塔人第一代
2014/03/25 DOTA
Python编程中的异常处理教程
2015/08/21 Python
python: line=f.readlines()消除line中\n的方法
2018/03/19 Python
flask框架视图函数用法示例
2018/07/19 Python
python调用百度地图WEB服务API获取地点对应坐标值
2019/01/16 Python
python实现银行管理系统
2019/10/25 Python
Python3列表List入门知识附实例
2020/02/09 Python
Abbott Lyon官网:女士手表、珠宝及配件
2020/12/26 全球购物
机械绘图员岗位职责
2013/11/19 职场文书
服装发布会策划方案
2014/05/22 职场文书
六一儿童节演讲稿
2014/05/23 职场文书
比赛口号大全
2014/06/10 职场文书
古见同学有交流障碍症 第二季宣传CM公开播出
2022/04/11 日漫
SQL Server中的逻辑函数介绍
2022/05/25 SQL Server
python如何利用cv2.rectangle()绘制矩形框
2022/12/24 Python