Getting Started with TensorFlow: Using TFRecord and tf.data.TFRecordDataset


Posted in Python on January 20, 2020

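1. Creating a TFRecord file

A TFRecord file can store three kinds of data: string (bytes), int64, and float32. Each value is written as a list, wrapped in tf.train.BytesList, tf.train.Int64List, or tf.train.FloatList respectively, and then placed in a tf.train.Feature, as shown below (the code in this article uses the TensorFlow 1.x API):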

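tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tobytes()])) # feature is usually a multi-dimensional NumPy array; serialize it to bytes (tobytes() replaces the deprecated tostring()) and wrap the result in a list
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape))) # the serialized bytes carry no shape information, so write the shape as well
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))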
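
Why the shape has to be stored next to the bytes can be illustrated without TensorFlow. The following is a minimal sketch that uses Python's standard array module as a stand-in for a NumPy array; to_bytes and from_bytes are hypothetical helper names introduced only for this illustration.

```python
from array import array

def to_bytes(rows):
    """Flatten a 2-D list of floats into raw float32 bytes; return (bytes, shape)."""
    shape = [len(rows), len(rows[0])]
    flat = [v for row in rows for v in row]
    return array('f', flat).tobytes(), shape

def from_bytes(raw, shape):
    """Rebuild the 2-D list from raw float32 bytes using the stored shape."""
    flat = array('f')
    flat.frombytes(raw)
    n_cols = shape[1]
    return [list(flat[r * n_cols:(r + 1) * n_cols]) for r in range(shape[0])]

image = [[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]
raw, shape = to_bytes(image)        # raw is a flat byte string with no row/column structure
restored = from_bytes(raw, shape)   # the shape is what makes the round trip possible
print(len(raw), shape, restored == image)
```

The byte string alone only says how many values there are; whether they form a 2x3 or a 3x2 array is recoverable only from the separately stored shape.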

Using these operations, collect the data to be written in a dict, build a tf.train.Features from it, and then build a tf.train.Example:

def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    # raw bytes of the array (tobytes() replaces the deprecated tostring())
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tobytes()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

Serialize the tf.train.Example, and it can then be written to a TFRecord file with tf.python_io.TFRecordWriter:

tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord') # create the TFRecord writer; the output file is xxx.tfrecord
exmp = get_tfrecords_example(feats[inx], labels[inx]) # build an Example from one sample
exmp_serial = exmp.SerializeToString()  # serialize the Example
tfrecord_wrt.write(exmp_serial)  # write it to the TFRecord file
tfrecord_wrt.close()  # close the writer when all records are written

Full code:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
# build an Example from one sample
def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tobytes()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
# write all samples to a TFRecord file
def make_tfrecord(data, outf_nm='mnist-train'):
    feats, labels = data
    outf_nm += '.tfrecord'
    tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
    ndatas = len(labels)
    for inx in range(ndatas):
        exmp = get_tfrecords_example(feats[inx], labels[inx])
        exmp_serial = exmp.SerializeToString()
        tfrecord_wrt.write(exmp_serial)
    tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = list(range(nDatas)) # random.shuffle needs a mutable list; range() is not one in Python 3
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')
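
The 85/15 split performed by the script above can be sketched in isolation. This is a minimal illustration in plain Python, not part of the original script: split_indices and its seed parameter are hypothetical, and the sample count 55000 is assumed because it equals the sum of the nDatasTrain and nDatasVal values used later in this article.

```python
import random

def split_indices(n_samples, train_frac=0.85, seed=0):
    """Shuffle the indices 0..n_samples-1 and split them into (train, val) lists."""
    indices = list(range(n_samples))      # a list, because shuffle works in place
    random.Random(seed).shuffle(indices)  # seeded only to make the sketch repeatable
    n_train = int(train_frac * n_samples)
    return indices[:n_train], indices[n_train:]

train_idx, val_idx = split_indices(55000)
print(len(train_idx), len(val_idx))
```

Slicing one shuffled index list guarantees that the training and validation sets are disjoint and together cover every sample.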

2. Using TFRecord files: tf.data.TFRecordDataset

Create a TFRecordDataset from a TFRecord file:

dataset = tf.data.TFRecordDataset('xxx.tfrecord')

Each record in the TFRecord file is a serialized tf.train.Example; parse it with tf.parse_single_example:

feats = tf.parse_single_example(serial_exmp, features=data_dict)

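Here data_dict is a dict whose keys are the ones used when writing the TFRecord file. The corresponding values are tf.FixedLenFeature([], tf.string), tf.FixedLenFeature([], tf.int64), or tf.FixedLenFeature([], tf.float32), matching the data type of each feature. Putting this together: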

def parse_exmp(serial_exmp):
    # 'label' uses [10] because each label is a list of 10 elements;
    # x in 'shape' stands for the length of the stored shape list (the number of dimensions).
    # A flattened MNIST image has a one-element shape, which is why [] also works for it.
    feats = tf.parse_single_example(serial_exmp, features={'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32), 'shape': tf.FixedLenFeature([x], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape

To parse every record in the TFRecord file, use the dataset's map method:

dataset = dataset.map(parse_exmp)

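map accepts any function for processing the elements of a dataset. In addition, the repeat, shuffle, and batch methods repeat, shuffle, and batch the dataset; repeat replays the dataset so that training can run for multiple epochs: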

dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
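
The order of repeat and batch matters when the number of samples is not divisible by the batch size. The following is a minimal pure-Python sketch of the resulting batch sizes, assuming no remainder is dropped; it is a simulation, not TensorFlow code, and the function names are hypothetical.

```python
def batch_sizes(n_samples, batch_size):
    """Sizes of the batches produced when n_samples items are cut into batches."""
    full, rest = divmod(n_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])

def repeat_then_batch(n_samples, epochs, batch_size):
    # repeat() first: the epochs form one long stream, which is then batched,
    # so a batch may span an epoch boundary
    return batch_sizes(n_samples * epochs, batch_size)

def batch_then_repeat(n_samples, epochs, batch_size):
    # batch() first: one epoch is batched, then those batches are replayed
    return batch_sizes(n_samples, batch_size) * epochs

a = repeat_then_batch(10, 3, 4)
b = batch_then_repeat(10, 3, 4)
print(a)  # only the very last batch is short
print(b)  # one short batch at the end of every epoch
```

With 10 samples, 3 epochs, and a batch size of 4, repeating first gives seven full batches and a single short one, while batching first gives a short batch in each of the three epochs.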

Once the data is parsed, create an iterator to fetch it:

iterator = dataset.make_one_shot_iterator()
batch_image, batch_label, batch_shape = iterator.get_next()

To feed data from different datasets into the same model, first create an iterator handle, i.e. an iterator placeholder:

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, \
 dataset_train.output_types, dataset_train.output_shapes)
image, label, shape = iterator.get_next()

Then obtain a handle for each dataset's iterator and pass it to the placeholder through feed_dict:

with tf.Session() as sess:
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    sess.run([loss, train_op], feed_dict={handle: handle_train})

Putting it together:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.FixedLenFeature([], tf.string),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.decode_raw(feats['feature'], tf.float32)
 label = feats['label']
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 # if nDatas is not divisible by batch_size (e.g. 56),
 # the final batch will contain fewer than batch_size samples
 
# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size) # make sure repeat comes before batch
  # this differs from dataset.shuffle(1000).batch(batch_size).repeat(epochs):
  # the latter yields a batch with fewer than batch_size samples in every epoch
  # when nDatas is not divisible by batch_size.
nBatchs = nDatasTrain*epochs//batch_size
 
# evaluation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs//100*2)
 
# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)
 
# make dataset iterator
iter_train = dataset_train.make_one_shot_iterator()
iter_val  = dataset_val.make_one_shot_iterator()
iter_test  = dataset_test.make_one_shot_iterator()
 
# make feedable iterator
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, \
 dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()
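train_op, loss, eval_op = model(x, y_) # model() stands for your own network definition (section 3 builds a CNN inline); it must also define the keep_prob placeholder fed below
init = tf.global_variables_initializer() # replaces the deprecated tf.initialize_all_variables()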
 
# summary
logdir = './logs/m4d2a'
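def summary_op(datapart='train'):
    # merge only this split's scalars; tf.summary.merge_all() would also
    # pick up the summaries created by earlier calls
    return tf.summary.merge([
        tf.summary.scalar(datapart + '-loss', loss),
        tf.summary.scalar(datapart + '-eval', eval_op)])
summary_op_train = summary_op()
summary_op_test = summary_op('val')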
 
with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    _, cur_loss, cur_train_eval, summary = sess.run([train_op, loss, eval_op, summary_op_train],
        feed_dict={handle: handle_train, keep_prob: 0.5})
    cur_val_loss, cur_val_eval, summary = sess.run([loss, eval_op, summary_op_test],
        feed_dict={handle: handle_val, keep_prob: 1.0})

3. MNIST experiment

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.FixedLenFeature([], tf.string),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.decode_raw(feats['feature'], tf.float32)
 label = feats['label']
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 # if nDatas is not divisible by batch_size (e.g. 56),
 # the final batch will contain fewer than batch_size samples
 
# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size) # make sure repeat comes before batch
  # this differs from dataset.shuffle(1000).batch(batch_size).repeat(epochs):
  # the latter yields a batch with fewer than batch_size samples in every epoch
  # when nDatas is not divisible by batch_size.
nBatchs = nDatasTrain*epochs//batch_size
 
# evaluation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs//100*2)
 
# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)
 
# make dataset iterator
iter_train = dataset_train.make_one_shot_iterator()
iter_val  = dataset_val.make_one_shot_iterator()
iter_test  = dataset_test.make_one_shot_iterator()
 
# make feedable iterator, i.e. iterator placeholder
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, \
 dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()
 
# cnn
x_image = tf.reshape(x, [-1,28,28,1])
w_init = tf.truncated_normal_initializer(stddev=0.1, seed=9)
b_init = tf.constant_initializer(0.1)
cnn1 = tf.layers.conv2d(x_image, 32, (5,5), padding='same', activation=tf.nn.relu, \
 kernel_initializer=w_init, bias_initializer=b_init)
mxpl1 = tf.layers.max_pooling2d(cnn1, 2, strides=2, padding='same')
cnn2 = tf.layers.conv2d(mxpl1, 64, (5,5), padding='same', activation=tf.nn.relu, \
 kernel_initializer=w_init, bias_initializer=b_init)
mxpl2 = tf.layers.max_pooling2d(cnn2, 2, strides=2, padding='same')
mxpl2_flat = tf.reshape(mxpl2, [-1,7*7*64])
fc1 = tf.layers.dense(mxpl2_flat, 1024, activation=tf.nn.relu, \
 kernel_initializer=w_init, bias_initializer=b_init)
keep_prob = tf.placeholder('float')
fc1_drop = tf.nn.dropout(fc1, keep_prob)
logits = tf.layers.dense(fc1_drop, 10, kernel_initializer=w_init, bias_initializer=b_init)
 
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
optmz = tf.train.AdamOptimizer(1e-4)
train_op = optmz.minimize(loss)
 
def get_eval_op(logits, labels):
 corr_prd = tf.equal(tf.argmax(logits,1), tf.argmax(labels,1))
 return tf.reduce_mean(tf.cast(corr_prd, 'float'))
eval_op = get_eval_op(logits, y_)
 
init = tf.global_variables_initializer() # replaces the deprecated tf.initialize_all_variables()
 
# summary
logdir = './logs/m4d2a'
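def summary_op(datapart='train'):
    # merge only this split's scalars; tf.summary.merge_all() would also
    # pick up the summaries created by earlier calls
    return tf.summary.merge([
        tf.summary.scalar(datapart + '-loss', loss),
        tf.summary.scalar(datapart + '-eval', eval_op)])
summary_op_train = summary_op()
summary_op_val = summary_op('val')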
 
# whether to restore or not
ckpts_dir = 'ckpts/'
ckpt_nm = 'cnn-ckpt'
saver = tf.train.Saver(max_to_keep=50) # defaults to save all variables, using dict {'x':x,...} to save specified ones.
restore_step = ''
start_step = 0
train_steps = nBatchs
best_loss = 1e6
best_step = 0
 
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# config = tf.ConfigProto() 
# config.gpu_options.per_process_gpu_memory_fraction = 0.9
# config.gpu_options.allow_growth=True # allocate when needed
# with tf.Session(config=config) as sess:
with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    if restore_step:
        ckpt = tf.train.get_checkpoint_state(ckpts_dir)
        if ckpt and ckpt.model_checkpoint_path: # ckpt.model_checkpoint_path means the latest ckpt
            if restore_step == 'latest':
                ckpt_f = tf.train.latest_checkpoint(ckpts_dir)
                start_step = int(ckpt_f.split('-')[-1]) + 1
            else:
                ckpt_f = ckpts_dir+ckpt_nm+'-'+restore_step
            print('loading wgt file: '+ ckpt_f)
            saver.restore(sess, ckpt_f)
    summary_wrt = tf.summary.FileWriter(logdir, sess.graph)
    if restore_step in ['', 'latest']:
        for i in range(start_step, train_steps):
            _, cur_loss, cur_train_eval, summary = sess.run([train_op, loss, eval_op, summary_op_train],
                feed_dict={handle: handle_train, keep_prob: 0.5})
            # log to stdout and eval validation set
            if i % 100 == 0 or i == train_steps-1:
                saver.save(sess, ckpts_dir+ckpt_nm, global_step=i) # save variables
                summary_wrt.add_summary(summary, global_step=i)
                cur_val_loss, cur_val_eval, summary = sess.run([loss, eval_op, summary_op_val],
                    feed_dict={handle: handle_val, keep_prob: 1.0})
                if cur_val_loss < best_loss:
                    best_loss = cur_val_loss
                    best_step = i
                summary_wrt.add_summary(summary, global_step=i)
                print('step %5d: loss %.5f, acc %.5f --- loss val %0.5f, acc val %.5f' % (i,
                    cur_loss, cur_train_eval, cur_val_loss, cur_val_eval))
            # sess.run(init_train)
        with open(ckpts_dir+'best.step','w') as f:
            f.write('best step is %d\n'%best_step)
        print('best step is %d'%best_step)
    # eval test set
    test_loss, test_eval = sess.run([loss, eval_op], feed_dict={handle: handle_test, keep_prob: 1.0})
    print('eval test: loss %.5f, acc %.5f'%(test_loss, test_eval))

Experimental results: [figure not preserved]
