TensorFlow basics: using tfrecord and tf.data.TFRecordDataset


Posted in Python on January 20, 2020

1. Creating tfrecord files

tfrecord supports three kinds of data: string (bytes), int64 and float32. Each is wrapped in a list, tf.train.BytesList, tf.train.Int64List or tf.train.FloatList respectively, and passed to tf.train.Feature, as follows:

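tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tobytes()]))  # feature is usually a multidimensional array; serialize it to a byte string first (tostring() is the deprecated alias of tobytes())
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))  # the shape is lost once the array becomes bytes, so write the shape as well
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))  # a scalar label has to be wrapped in a list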
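To see why the shape is stored next to the raw bytes, here is a minimal stdlib-only sketch. It uses `struct` as a stand-in for NumPy's `tobytes()`/`np.frombuffer`, and the helpers `to_bytes`/`from_bytes` are hypothetical: a 2×3 and a 3×2 float32 array serialize to exactly the same bytes, so only the stored shape can tell them apart.

```python
import struct

def to_bytes(rows):
    """Flatten a 2-D list of floats into little-endian float32 bytes (like ndarray.tobytes() on x86)."""
    flat = [v for row in rows for v in row]
    return struct.pack('<%df' % len(flat), *flat)

def from_bytes(raw, shape):
    """Decode float32 bytes back into a flat list, then rebuild rows using the stored shape."""
    flat = list(struct.unpack('<%df' % (len(raw) // 4), raw))
    n_rows, n_cols = shape
    return [flat[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]

a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]    # shape (2, 3)
b = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # shape (3, 2)
raw_a, raw_b = to_bytes(a), to_bytes(b)

print(len(raw_a), raw_a == raw_b)       # 24 True: the bytes carry no shape information
print(from_bytes(raw_a, (2, 3)) == a)   # True once the shape is supplied
```

On the reading side, tf.decode_raw plays the role of `from_bytes` without the reshape, which is why the shape feature is parsed back out as well.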

Collect the features to be written into a dict, build a tf.train.Features from it, and wrap that in a tf.train.Example:

def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tobytes()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))  # label is already a (one-hot) list here
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

Serialize the resulting tf.train.Example and write it to a tfrecord file with tf.python_io.TFRecordWriter:

tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord')  # create the tfrecord writer; the file is xxx.tfrecord
exmp = get_tfrecords_example(feats[inx], labels[inx])  # put one sample into an Example
exmp_serial = exmp.SerializeToString()  # serialize the Example
tfrecord_wrt.write(exmp_serial)  # write it to the tfrecord file
tfrecord_wrt.close()  # close the writer when finished
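Each call to `write` appends one length-prefixed, checksummed record. As a stdlib-only sketch of that framing, based on the record layout described in the TensorFlow TFRecord guide (uint64 length, masked CRC-32C of the length, the data, masked CRC-32C of the data, all little-endian), the hypothetical `write_records`/`read_records` helpers below show what a .tfrecord file actually contains:

```python
import os
import struct
import tempfile

def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli, reflected polynomial 0x82F63B78)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF

def masked_crc(data: bytes) -> int:
    """TFRecord stores the CRC rotated right by 15 bits plus a constant offset."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF

def write_records(path, records):
    """Write each bytes record as: length, crc(length), data, crc(data)."""
    with open(path, 'wb') as f:
        for rec in records:
            header = struct.pack('<Q', len(rec))
            f.write(header)
            f.write(struct.pack('<I', masked_crc(header)))
            f.write(rec)
            f.write(struct.pack('<I', masked_crc(rec)))

def read_records(path):
    """Yield the records back, verifying both checksums."""
    with open(path, 'rb') as f:
        while True:
            header = f.read(8)
            if not header:
                break  # clean end of file
            (length,) = struct.unpack('<Q', header)
            (len_crc,) = struct.unpack('<I', f.read(4))
            if len_crc != masked_crc(header):
                raise ValueError('corrupted length header')
            rec = f.read(length)
            (data_crc,) = struct.unpack('<I', f.read(4))
            if data_crc != masked_crc(rec):
                raise ValueError('corrupted record')
            yield rec

path = os.path.join(tempfile.mkdtemp(), 'demo.tfrecord')
write_records(path, [b'hello', b'tfrecord'])  # stand-ins for exmp.SerializeToString()
print(list(read_records(path)))  # [b'hello', b'tfrecord']
```

In practice tf.python_io.TFRecordWriter and tf.data.TFRecordDataset do this framing (plus optional compression) for you; the sketch only illustrates the file layout, with 16 bytes of overhead per record.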

Complete code:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
# put one sample into an Example
def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tobytes()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

# write all samples to a tfrecord file
def make_tfrecord(data, outf_nm='mnist-train'):
    feats, labels = data
    outf_nm += '.tfrecord'
    tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
    ndatas = len(labels)
    for inx in range(ndatas):
        exmp = get_tfrecords_example(feats[inx], labels[inx])
        exmp_serial = exmp.SerializeToString()
        tfrecord_wrt.write(exmp_serial)
    tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = list(range(nDatas))  # random.shuffle needs a mutable list; in Python 3 range() cannot be shuffled in place
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')
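The 85/15 split can be checked without TensorFlow. Assuming mnist.train holds 55,000 images (the read_data_sets default keeps 5,000 of the 60,000 training images as its own validation set), this sketch reproduces the nDatasTrain = 46750 and nDatasVal = 8250 used in section 2. `split_indices` is a hypothetical helper, and the fixed seed is an addition so the split is reproducible:

```python
import random

def split_indices(n, train_frac=0.85, seed=0):
    """Shuffle the indices 0..n-1 and cut the permutation into train/validation parts."""
    idx = list(range(n))              # a list, not a bare range, so it can be shuffled in place
    random.Random(seed).shuffle(idx)  # seeded generator: same split on every run
    n_train = int(train_frac * n)
    return idx[:n_train], idx[n_train:]

train_idx, val_idx = split_indices(55000)
print(len(train_idx), len(val_idx))  # 46750 8250
```

Because both parts come from one permutation, every sample lands in exactly one of the two files.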

2. Using tfrecord files: tf.data.TFRecordDataset

Create a TFRecordDataset from a tfrecord file:

dataset = tf.data.TFRecordDataset('xxx.tfrecord')

Each record in the file is a serialized tf.train.Example; parse it with tf.parse_single_example:

feats = tf.parse_single_example(serial_exmp, features=data_dict)

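Here data_dict is a dict whose keys are the keys used when the tfrecord file was written. Each value is a tf.FixedLenFeature describing that key's shape and type, e.g. tf.FixedLenFeature([], tf.string), tf.FixedLenFeature([], tf.int64) or tf.FixedLenFeature([], tf.float32); a shape of [] means a single value. Putting this together: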

def parse_exmp(serial_exmp):
    # 'label' has shape [10] because each label is a one-hot list of 10 floats;
    # x in 'shape' is the number of stored shape values, i.e. the rank of the original feature
    feats = tf.parse_single_example(serial_exmp, features={
        'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32),
        'shape': tf.FixedLenFeature([x], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)  # raw bytes -> flat float32 vector
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape

To parse every record in the file, use the dataset's map method:

dataset = dataset.map(parse_exmp)

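map accepts an arbitrary function for processing the elements of the dataset. The repeat, shuffle and batch methods repeat, shuffle and batch the dataset respectively; repeat is what gives you multiple epochs: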

dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
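The position of repeat matters when the dataset size is not a multiple of the batch size. The generators below are hypothetical pure-Python stand-ins for repeat, shuffle and batch (the shuffle is a simplified buffered shuffle, not tf.data's exact implementation); they only count how many samples each batch receives:

```python
import random
from itertools import islice

def repeat(items, epochs):
    """Like Dataset.repeat(epochs); items must be re-iterable (e.g. a list)."""
    for _ in range(epochs):
        yield from items

def shuffle(stream, buffer_size, rng):
    """Simplified buffered shuffle in the spirit of Dataset.shuffle(buffer_size)."""
    buf = []
    for x in stream:
        buf.append(x)
        if len(buf) >= buffer_size:
            yield buf.pop(rng.randrange(len(buf)))  # emit a random buffered element
    while buf:                                      # drain the buffer at the end
        yield buf.pop(rng.randrange(len(buf)))

def batch(stream, batch_size):
    """Like Dataset.batch(batch_size): the final batch may be smaller."""
    it = iter(stream)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        yield chunk

data = list(range(10))  # 10 samples, not a multiple of batch_size = 4
rng = random.Random(0)

repeat_first = list(batch(shuffle(repeat(data, 2), 3, rng), 4))  # repeat -> shuffle -> batch
batch_first = list(repeat(list(batch(data, 4)), 2))             # batch -> repeat (shuffle omitted)

print([len(c) for c in repeat_first])  # [4, 4, 4, 4, 4]
print([len(c) for c in batch_first])   # [4, 4, 2, 4, 4, 2]
```

Putting repeat first also means a shuffled batch can mix samples from two adjacent epochs, which is usually harmless for training.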

Once parsing is set up, pull data out of the dataset by creating an iterator:

iterator = dataset.make_one_shot_iterator()
batch_image, batch_label, batch_shape = iterator.get_next()

To feed data from different datasets into the same model, first create an iterator handle, i.e. an iterator placeholder:

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, \
 dataset_train.output_types, dataset_train.output_shapes)
image, label, shape = iterator.get_next()

Then obtain a string handle for each dataset's iterator and pass it to the placeholder through feed_dict:

with tf.Session() as sess:
    handle_train, handle_val, handle_test = sess.run(
        [it.string_handle() for it in [iter_train, iter_val, iter_test]])
    sess.run([loss, train_op], feed_dict={handle: handle_train})
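The point of the handle is that one get_next() op can be switched between iterators while each underlying iterator keeps its own position. The hypothetical `ToyFeedableIterator` below is only a plain-Python analogy of that idea, not TensorFlow's API:

```python
from itertools import count

class ToyFeedableIterator:
    """Toy analogy of a feedable iterator: one get_next() that dispatches on a string handle."""

    def __init__(self):
        self._iterators = {}  # handle string -> underlying Python iterator

    def register(self, name, iterable):
        """Wrap an iterable and return the string handle that selects it."""
        self._iterators[name] = iter(iterable)
        return name

    def get_next(self, handle):
        """Pull the next element from whichever iterator the handle names."""
        return next(self._iterators[handle])

feedable = ToyFeedableIterator()
h_train = feedable.register('train', count())  # endless "training" stream 0, 1, 2, ...
h_val = feedable.register('val', ['v0', 'v1'])

steps = [feedable.get_next(h_train), feedable.get_next(h_train),
         feedable.get_next(h_val),    # switch to validation for one step
         feedable.get_next(h_train)]  # training resumes where it stopped
print(steps)  # [0, 1, 'v0', 2]
```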

Putting it all together:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.FixedLenFeature([], tf.string),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.decode_raw(feats['feature'], tf.float32)
 label = feats['label']
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50  # when nDatas is not a multiple of batch_size (e.g. batch_size = 56),
                 # one batch will hold fewer than batch_size samples
 
# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)  # make sure repeat comes before batch
# this differs from dataset.shuffle(1000).batch(batch_size).repeat(epochs):
# with the latter, when nDatas is not a multiple of batch_size,
# every epoch ends with a batch smaller than batch_size
nBatchs = nDatasTrain*epochs//batch_size
 
# evaluation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs//100*2)
 
# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)
 
# make dataset iterator
iter_train = dataset_train.make_one_shot_iterator()
iter_val  = dataset_val.make_one_shot_iterator()
iter_test  = dataset_test.make_one_shot_iterator()
 
# make feedable iterator
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, \
 dataset_train.output_types, dataset_train.output_shapes)
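x, y_, _ = iterator.get_next()
# model() stands for your own network (section 3 builds one inline); it is assumed
# to also create the keep_prob placeholder that is fed below
train_op, loss, eval_op = model(x, y_)
init = tf.global_variables_initializer()  # replaces the deprecated tf.initialize_all_variables()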
 
# summary
logdir = './logs/m4d2a'
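def summary_op(datapart='train'):
    # merge only this part's summaries; tf.summary.merge_all() would also pick up
    # the summaries created earlier for the other part
    loss_sum = tf.summary.scalar(datapart + '-loss', loss)
    eval_sum = tf.summary.scalar(datapart + '-eval', eval_op)
    return tf.summary.merge([loss_sum, eval_sum])
summary_op_train = summary_op()
summary_op_test = summary_op('val')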
 
with tf.Session() as sess:
    sess.run(init)
    # evaluate the string handles of the three iterators once
    handle_train, handle_val, handle_test = sess.run(
        [it.string_handle() for it in [iter_train, iter_val, iter_test]])
    # inside the training loop: one training step on a training batch
    _, cur_loss, cur_train_eval, summary = sess.run(
        [train_op, loss, eval_op, summary_op_train],
        feed_dict={handle: handle_train, keep_prob: 0.5})
    # periodically: evaluate on the validation set
    cur_val_loss, cur_val_eval, summary = sess.run(
        [loss, eval_op, summary_op_test],
        feed_dict={handle: handle_val, keep_prob: 1.0})
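The step and repeat counts above can be checked with plain arithmetic. This sketch uses the article's numbers (46,750 training samples, 16 epochs, batch size 50) and counts how many validation evaluations the training loop in section 3 performs (every 100 steps plus the last step); since dataset_val is the whole validation set as one batch, each evaluation consumes one element of it:

```python
n_train, epochs, batch_size = 46750, 16, 50

n_batches = n_train * epochs // batch_size  # nBatchs: training steps over all epochs
leftover = n_train * epochs % batch_size    # samples in a final partial batch (0 = none)

# the loop evaluates on the validation set every 100 steps and at the last step
eval_steps = [i for i in range(n_batches) if i % 100 == 0 or i == n_batches - 1]

# dataset_val is repeated this many times
val_repeats = n_batches // 100 * 2

print(n_batches, leftover, len(eval_steps), val_repeats)  # 14960 0 151 298
```

With 298 repeats for 151 evaluations, the validation iterator never runs out during training.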

3. MNIST experiment

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.FixedLenFeature([], tf.string),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.decode_raw(feats['feature'], tf.float32)
 label = feats['label']
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50  # when nDatas is not a multiple of batch_size (e.g. batch_size = 56),
                 # one batch will hold fewer than batch_size samples
 
# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)  # make sure repeat comes before batch
# this differs from dataset.shuffle(1000).batch(batch_size).repeat(epochs):
# with the latter, when nDatas is not a multiple of batch_size,
# every epoch ends with a batch smaller than batch_size
nBatchs = nDatasTrain*epochs//batch_size
 
# evaluation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs//100*2)
 
# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)
 
# make dataset iterator
iter_train = dataset_train.make_one_shot_iterator()
iter_val  = dataset_val.make_one_shot_iterator()
iter_test  = dataset_test.make_one_shot_iterator()
 
# make feedable iterator, i.e. iterator placeholder
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, \
 dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()
 
# cnn
x_image = tf.reshape(x, [-1,28,28,1])
w_init = tf.truncated_normal_initializer(stddev=0.1, seed=9)
b_init = tf.constant_initializer(0.1)
cnn1 = tf.layers.conv2d(x_image, 32, (5,5), padding='same', activation=tf.nn.relu, \
 kernel_initializer=w_init, bias_initializer=b_init)
mxpl1 = tf.layers.max_pooling2d(cnn1, 2, strides=2, padding='same')
cnn2 = tf.layers.conv2d(mxpl1, 64, (5,5), padding='same', activation=tf.nn.relu, \
 kernel_initializer=w_init, bias_initializer=b_init)
mxpl2 = tf.layers.max_pooling2d(cnn2, 2, strides=2, padding='same')
mxpl2_flat = tf.reshape(mxpl2, [-1,7*7*64])
fc1 = tf.layers.dense(mxpl2_flat, 1024, activation=tf.nn.relu, \
 kernel_initializer=w_init, bias_initializer=b_init)
keep_prob = tf.placeholder('float')
fc1_drop = tf.nn.dropout(fc1, keep_prob)
logits = tf.layers.dense(fc1_drop, 10, kernel_initializer=w_init, bias_initializer=b_init)
 
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
optmz = tf.train.AdamOptimizer(1e-4)
train_op = optmz.minimize(loss)
 
def get_eval_op(logits, labels):
 corr_prd = tf.equal(tf.argmax(logits,1), tf.argmax(labels,1))
 return tf.reduce_mean(tf.cast(corr_prd, 'float'))
eval_op = get_eval_op(logits, y_)
 
init = tf.global_variables_initializer()  # replaces the deprecated tf.initialize_all_variables()
 
# summary
logdir = './logs/m4d2a'
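def summary_op(datapart='train'):
    # merge only this part's summaries; tf.summary.merge_all() would also pick up
    # the summaries created earlier for the other part
    loss_sum = tf.summary.scalar(datapart + '-loss', loss)
    eval_sum = tf.summary.scalar(datapart + '-eval', eval_op)
    return tf.summary.merge([loss_sum, eval_sum])
summary_op_train = summary_op()
summary_op_val = summary_op('val')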
 
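# checkpoint settings: restore_step = '' trains from scratch, 'latest' resumes from the newest
# checkpoint, and a step number given as a string (e.g. '1400') only restores that checkpoint
# and evaluates the test set
import os
ckpts_dir = 'ckpts/'
if not os.path.isdir(ckpts_dir):
    os.makedirs(ckpts_dir)  # Saver.save does not create a missing parent directory
ckpt_nm = 'cnn-ckpt'
saver = tf.train.Saver(max_to_keep=50)  # saves all variables by default; pass a dict {'x': x, ...} to save only specific ones
restore_step = ''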
start_step = 0
train_steps = nBatchs
best_loss = 1e6
best_step = 0
 
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# config = tf.ConfigProto() 
# config.gpu_options.per_process_gpu_memory_fraction = 0.9
# config.gpu_options.allow_growth=True # allocate when needed
# with tf.Session(config=config) as sess:
with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [it.string_handle() for it in [iter_train, iter_val, iter_test]])
    if restore_step:
        ckpt = tf.train.get_checkpoint_state(ckpts_dir)
        if ckpt and ckpt.model_checkpoint_path:  # ckpt.model_checkpoint_path is the latest ckpt
            if restore_step == 'latest':
                ckpt_f = tf.train.latest_checkpoint(ckpts_dir)
                start_step = int(ckpt_f.split('-')[-1]) + 1
            else:
                ckpt_f = ckpts_dir + ckpt_nm + '-' + restore_step
            print('loading wgt file: ' + ckpt_f)
            saver.restore(sess, ckpt_f)
    summary_wrt = tf.summary.FileWriter(logdir, sess.graph)
    if restore_step in ['', 'latest']:
        for i in range(start_step, train_steps):
            _, cur_loss, cur_train_eval, summary = sess.run(
                [train_op, loss, eval_op, summary_op_train],
                feed_dict={handle: handle_train, keep_prob: 0.5})
            # log to stdout and evaluate the validation set
            if i % 100 == 0 or i == train_steps - 1:
                saver.save(sess, ckpts_dir + ckpt_nm, global_step=i)  # save variables
                summary_wrt.add_summary(summary, global_step=i)
                cur_val_loss, cur_val_eval, summary = sess.run(
                    [loss, eval_op, summary_op_val],
                    feed_dict={handle: handle_val, keep_prob: 1.0})
                if cur_val_loss < best_loss:
                    best_loss = cur_val_loss
                    best_step = i
                summary_wrt.add_summary(summary, global_step=i)
                print('step %5d: loss %.5f, acc %.5f --- loss val %0.5f, acc val %.5f' % (
                    i, cur_loss, cur_train_eval, cur_val_loss, cur_val_eval))
        with open(ckpts_dir + 'best.step', 'w') as f:
            f.write('best step is %d\n' % best_step)
        print('best step is %d' % best_step)
    summary_wrt.close()  # flush pending summaries to disk
    # evaluate the test set
    test_loss, test_eval = sess.run([loss, eval_op],
                                    feed_dict={handle: handle_test, keep_prob: 1.0})
    print('eval test: loss %.5f, acc %.5f' % (test_loss, test_eval))
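The network's shape arithmetic and its loss/accuracy ops can be re-derived by hand. The following is a pure-Python sketch, not TensorFlow's implementation; the helper names are hypothetical. It checks why the flattened size is 7*7*64 ('same' padding gives an output length of ceil(size / stride)) and computes the same quantities as `loss` and `get_eval_op` on a tiny batch:

```python
import math

def same_out(size, stride):
    """Output length of a 'same'-padded conv/pool layer: ceil(size / stride)."""
    return math.ceil(size / stride)

# each conv (stride 1) keeps the size; each 2x2 max-pool with stride 2 halves it
h = same_out(same_out(28, 1), 2)  # after cnn1 + mxpl1: 14
h = same_out(same_out(h, 1), 2)   # after cnn2 + mxpl2: 7
flat = h * h * 64                 # matches the reshape to [-1, 7*7*64]

def log_softmax(z):
    """Numerically stable log-softmax of one row of logits."""
    m = max(z)
    log_sum = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - log_sum for v in z]

def mean_softmax_xent(logits, one_hot):
    """Batch mean of -sum(label * log_softmax(logits)), the quantity behind `loss`."""
    per_row = [-sum(t * lp for t, lp in zip(y, log_softmax(z)))
               for z, y in zip(logits, one_hot)]
    return sum(per_row) / len(per_row)

def argmax(row):
    """Index of the first maximal entry."""
    return max(range(len(row)), key=lambda k: row[k])

def accuracy(logits, one_hot):
    """Fraction of rows where argmax(logits) == argmax(label), as in get_eval_op."""
    hits = sum(argmax(z) == argmax(y) for z, y in zip(logits, one_hot))
    return hits / len(logits)

logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
labels = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
print(h, flat)                   # 7 3136
print(accuracy(logits, labels))  # 0.5: the second row predicts class 2, not 1
print(mean_softmax_xent(logits, labels))
```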

Experimental results:

[Figure: experimental results]
