tensorflow入门:TFRecordDataset变长数据的batch读取详解


Posted in Python onJanuary 20, 2020

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现从一组颜色中找出与给定颜色最接近颜色的方法
Mar 19 Python
Python的Flask框架与数据库连接的教程
Apr 20 Python
对python多线程SSH登录并发脚本详解
Feb 14 Python
Python面向对象程序设计类的多态用法详解
Apr 12 Python
pyqt实现.ui文件批量转换为对应.py文件脚本
Jun 19 Python
python3.7 的新特性详解
Jul 25 Python
Python:二维列表下标互换方式(矩阵转置)
Dec 02 Python
python读取与处理netcdf数据方式
Feb 14 Python
Django 如何使用日期时间选择器规范用户的时间输入示例代码详解
May 22 Python
Python3通过chmod修改目录或文件权限的方法示例
Jun 08 Python
python开根号实例讲解
Aug 30 Python
matplotlib自定义鼠标光标坐标格式的实现
Jan 08 Python
python如何通过pyqt5实现进度条
Jan 20 #Python
python super用法及原理详解
Jan 20 #Python
tensorflow 变长序列存储实例
Jan 20 #Python
在tensorflow中实现去除不足一个batch的数据
Jan 20 #Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 #Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
You might like
php判断输入是否是纯数字,英文,汉字的方法
2015/03/05 PHP
详解PHP中的状态模式编程
2015/08/11 PHP
PHP实现发送邮件的方法(基于简单邮件发送类)
2015/12/17 PHP
Symfony2框架学习笔记之HTTP Cache用法详解
2016/03/18 PHP
PHP用FTP类上传文件视频等的简单实现方法
2016/09/23 PHP
PHP程序员简单的开展服务治理架构操作详解(二)
2020/05/14 PHP
JS 自动安装exe程序
2008/11/30 Javascript
获取HTML DOM节点元素的方法的总结
2009/08/21 Javascript
javascript自然分类法算法实现代码
2013/10/11 Javascript
JavaScript中eval函数的问题
2016/01/31 Javascript
js运动事件函数详解
2016/10/21 Javascript
微信小程序 限制1M的瘦身技巧与方法详解
2017/01/06 Javascript
微信小程序tabbar不显示解决办法
2017/06/08 Javascript
jQuery实现手势解锁密码特效
2017/08/14 jQuery
vue综合组件间的通信详解
2017/11/06 Javascript
前端html中jQuery实现对文本的搜索功能并把搜索相关内容显示出来
2017/11/14 jQuery
微信小程序实现点击按钮修改字体颜色功能【附demo源码下载】
2017/12/05 Javascript
微信小程序chooseImage的用法(从本地相册选择图片或使用相机拍照)
2018/08/22 Javascript
JS面向对象编程实现的拖拽功能案例详解
2020/03/03 Javascript
Python ORM框架SQLAlchemy学习笔记之数据添加和事务回滚介绍
2014/06/10 Python
python处理xml文件的方法小结
2017/05/02 Python
python用opencv批量截取图像指定区域的方法
2019/01/24 Python
Python装饰器原理与基本用法分析
2020/01/07 Python
Django 实现 Websocket 广播、点对点发送消息的代码
2020/06/03 Python
Python操作Elasticsearch处理timeout超时
2020/07/17 Python
英国时尚饰品和发饰购物网站:Claire’s
2017/07/04 全球购物
《一件运动衫》教学反思
2014/02/19 职场文书
喝酒检查书范文
2014/02/23 职场文书
社区工作者演讲稿
2014/05/23 职场文书
2013年最新自荐信范文
2014/06/23 职场文书
2014年乡镇安全生产工作总结
2014/12/02 职场文书
新学期开学寄语2016
2015/12/04 职场文书
工作计划范文之财务管理
2019/08/09 职场文书
Python爬虫中urllib3与urllib的区别是什么
2021/07/21 Python
python turtle绘制多边形和跳跃和改变速度特效
2022/03/16 Python
Django框架之路由用法
2022/06/10 Python