TensorFlow basics: using tfrecord and tf.data.TFRecordDataset


Posted in Python on January 20, 2020

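1. Creating a tfrecord file

A tfrecord file can store three kinds of data: byte strings, int64, and float32. Each value is written as a list, wrapped in tf.train.BytesList, tf.train.Int64List, or tf.train.FloatList respectively, and then placed in a tf.train.Feature:

tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tobytes()]))  # feature is usually a multi-dimensional array; serialize it to raw bytes first
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))  # serializing to bytes discards the array's shape, so store the shape as well
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))

Collect the features to be written in a dict, build a tf.train.Features from it, and then build a tf.train.Example: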

def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    # raw bytes of the array (ndarray.tobytes() is the current name of tostring())
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tobytes()]))
    # the shape, so that the array can be restored after decoding
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    # label is already a sequence here (e.g. a one-hot vector), so it is passed as is
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
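
Why the shape is stored next to the raw bytes can be illustrated without TensorFlow. The following is a minimal, standard-library-only sketch and not the article's code: to_bytes and from_bytes are hypothetical helpers that mimic, for a 2-D list of float32 values, what serializing an array to bytes and later decoding and reshaping it do.

```python
import struct

def to_bytes(rows):
    """Flatten a 2-D list of floats into raw float32 bytes (the shape is not stored)."""
    flat = [v for row in rows for v in row]
    return struct.pack('<%df' % len(flat), *flat)

def from_bytes(raw, shape):
    """Rebuild the 2-D list from raw float32 bytes using a separately stored shape."""
    n_rows, n_cols = shape
    flat = struct.unpack('<%df' % (n_rows * n_cols), raw)
    return [list(flat[r * n_cols:(r + 1) * n_cols]) for r in range(n_rows)]

image = [[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]        # a 2x3 "image"
other = [[0.0, 0.5], [1.0, 1.5], [2.0, 2.5]]      # a 3x2 array holding the same values

raw = to_bytes(image)
# Both arrays serialize to identical bytes, so the bytes alone cannot tell them apart.
same_bytes = raw == to_bytes(other)
# With the stored shape the original layout can be recovered.
restored = from_bytes(raw, (2, 3))
```

The same byte string decodes to either layout, which is why the 'shape' feature is written alongside 'feature'.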

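Serialize the tf.train.Example and write it to a tfrecord file with tf.python_io.TFRecordWriter (the TensorFlow 1.x name; later versions expose the same class as tf.io.TFRecordWriter):

tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord')  # create the writer; the output file is xxx.tfrecord
exmp = get_tfrecords_example(feats[inx], labels[inx])  # pack one sample into an Example
exmp_serial = exmp.SerializeToString()  # serialize the Example
tfrecord_wrt.write(exmp_serial)  # write the record to the file
tfrecord_wrt.close()  # close the writer when done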

Putting it together:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

mnist = read_data_sets("MNIST_data/", one_hot=True)

# pack one sample into an Example
def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tobytes()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

# write all samples to a tfrecord file
def make_tfrecord(data, outf_nm='mnist-train'):
    feats, labels = data
    outf_nm += '.tfrecord'
    tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
    ndatas = len(labels)
    for inx in range(ndatas):
        exmp = get_tfrecords_example(feats[inx], labels[inx])
        exmp_serial = exmp.SerializeToString()
        tfrecord_wrt.write(exmp_serial)
    tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = list(range(nDatas))  # a list, so that random.shuffle can shuffle it in place
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)  # 85% for training, the remaining 15% for validation

# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]],
        [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')

# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]],
        [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')

# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')
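
The index split used above can be tried on its own, without TensorFlow or the MNIST files. This is a small sketch under stated assumptions: split_indices is a hypothetical helper, the seed is arbitrary, and the total of 55,000 samples is an assumption chosen because it is consistent with the sizes 46750 and 8250 that appear later in this article.

```python
import random

def split_indices(n_total, train_percent=85, seed=0):
    """Shuffle range(n_total) and split it into train and validation index lists."""
    indices = list(range(n_total))            # must be a list: random.shuffle works in place
    random.Random(seed).shuffle(indices)
    n_train = n_total * train_percent // 100  # integer arithmetic avoids float rounding
    return indices[:n_train], indices[n_train:]

train_idx, val_idx = split_indices(55000)
print(len(train_idx), len(val_idx))  # prints: 46750 8250
```

Because both parts are slices of one shuffled list, no sample ends up in both the training and the validation file.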

2. Using a tfrecord file: tf.data.TFRecordDataset

Create a TFRecordDataset from a tfrecord file:

dataset = tf.data.TFRecordDataset('xxx.tfrecord')

Each record in the file is a serialized tf.train.Example; parse it with tf.parse_single_example:

feats = tf.parse_single_example(serial_exmp, features=data_dict)

Here data_dict is a dict whose keys are the keys used when the tfrecord file was written, and whose values are tf.FixedLenFeature([], tf.string), tf.FixedLenFeature([], tf.int64), or tf.FixedLenFeature([], tf.float32), according to the data type of each feature. Putting it together:

def parse_exmp(serial_exmp):
    # 'label' uses [10] because each label is a list of 10 elements;
    # the x in [x] for 'shape' is the length of the stored shape list
    feats = tf.parse_single_example(serial_exmp, features={
        'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32),
        'shape': tf.FixedLenFeature([x], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape

To parse every record in the file, apply the function with the dataset's map method:

dataset = dataset.map(parse_exmp)

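map accepts any function for processing the elements of a dataset. In addition, the repeat, shuffle, and batch methods repeat, shuffle, and batch the dataset; repeat duplicates the dataset so that it can be iterated over for several epochs: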

dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
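
The order of repeat and batch changes how many batches are produced and how large the last ones are. The sketch below only counts elements in plain Python; it is not TensorFlow code, it ignores shuffle, and it assumes the default behaviour in which a final partial batch is kept. batch_sizes is a hypothetical helper.

```python
def batch_sizes(n_items, batch_size, epochs, repeat_first):
    """Return the list of batch sizes produced by the two orderings."""
    def batches(n):
        full, rest = divmod(n, batch_size)
        return [batch_size] * full + ([rest] if rest else [])
    if repeat_first:
        # repeat(epochs).batch(batch_size): batching sees one long stream
        return batches(n_items * epochs)
    # batch(batch_size).repeat(epochs): each epoch is batched on its own
    return batches(n_items) * epochs

a = batch_sizes(10, 4, 3, repeat_first=True)
b = batch_sizes(10, 4, 3, repeat_first=False)
print(a)  # prints: [4, 4, 4, 4, 4, 4, 4, 2]
print(b)  # prints: [4, 4, 2, 4, 4, 2, 4, 4, 2]
```

With repeat first there is at most one short batch at the very end, but batches may straddle epoch boundaries; with batch first every epoch ends with its own short batch.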

Once the data is parsed, create an iterator to fetch it:

iterator = dataset.make_one_shot_iterator()
batch_image, batch_label, batch_shape = iterator.get_next()

To feed data from different datasets into the same model, first create an iterator handle, i.e. an iterator placeholder:

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
image, label, shape = iterator.get_next()

Then obtain a string handle from each dataset's iterator and pass it to the placeholder through feed_dict:

with tf.Session() as sess:
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    sess.run([loss, train_op], feed_dict={handle: handle_train})

Putting it together:

import tensorflow as tf

train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]

def parse_exmp(serial_exmp):
    # the MNIST images written above are flat vectors, so 'shape' holds a single value
    # and can be parsed with shape []
    feats = tf.parse_single_example(serial_exmp, features={
        'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32),
        'shape': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape


def get_dataset(fname):
    dataset = tf.data.TFRecordDataset(fname)
    return dataset.map(parse_exmp)  # use the padded_batch method if padding is needed

epochs = 16
batch_size = 50  # when nDatas is not divisible by batch_size (e.g. batch_size = 56),
                 # there will be a batch with fewer than batch_size samples

# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)  # make sure repeat comes before batch
# this differs from dataset.shuffle(1000).batch(batch_size).repeat(epochs):
# the latter produces a batch with fewer than batch_size samples in every epoch
# when nDatas is not divisible by batch_size
nBatchs = nDatasTrain*epochs//batch_size

# evaluation dataset: one batch holds the whole validation set, repeated so that
# it can be evaluated many times during training
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs//100*2)

# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)
 
# make dataset iterator
iter_train = dataset_train.make_one_shot_iterator()
iter_val = dataset_val.make_one_shot_iterator()
iter_test = dataset_test.make_one_shot_iterator()

# make feedable iterator
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()
# model() stands for the network-building code shown in section 3; it is also
# expected to create the keep_prob placeholder used below
train_op, loss, eval_op = model(x, y_)
init = tf.global_variables_initializer()

# summary
logdir = './logs/m4d2a'
def summary_op(datapart='train'):
    # merge only the two scalars created here; merge_all() would also pick up
    # the summaries created by earlier calls
    return tf.summary.merge([
        tf.summary.scalar(datapart + '-loss', loss),
        tf.summary.scalar(datapart + '-eval', eval_op)])
summary_op_train = summary_op()
summary_op_val = summary_op('val')

with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    _, cur_loss, cur_train_eval, summary = sess.run(
        [train_op, loss, eval_op, summary_op_train],
        feed_dict={handle: handle_train, keep_prob: 0.5})
    cur_val_loss, cur_val_eval, summary = sess.run(
        [loss, eval_op, summary_op_val],
        feed_dict={handle: handle_val, keep_prob: 1.0})
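
A note on shuffle(1000) in the pipeline above: a shuffle buffer that is much smaller than the dataset only mixes elements locally. The following plain-Python sketch imitates the buffer behaviour described in the tf.data documentation (fill a buffer, emit a random element from it, replace it with the next input element). It is an illustration, not TensorFlow's implementation; buffered_shuffle and the seed are hypothetical.

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items in an order a fixed-size shuffle buffer could produce."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)          # fill the buffer first
            continue
        pos = rng.randrange(buffer_size)
        yield buffer[pos]                # emit a random buffered element ...
        buffer[pos] = item               # ... and put the new element in its slot
    rng.shuffle(buffer)                  # input exhausted: drain what is left
    for item in buffer:
        yield item

out = list(buffered_shuffle(range(20), buffer_size=4))
unshuffled = list(buffered_shuffle(range(20), buffer_size=1))
```

With buffer_size=4 an element can be emitted at most 3 positions earlier than its position in the input, and with buffer_size=1 the order does not change at all. A buffer at least as large as the dataset is needed for a uniform shuffle.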

3. MNIST experiment

import tensorflow as tf

train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]

def parse_exmp(serial_exmp):
    feats = tf.parse_single_example(serial_exmp, features={
        'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32),
        'shape': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape


def get_dataset(fname):
    dataset = tf.data.TFRecordDataset(fname)
    return dataset.map(parse_exmp)  # use the padded_batch method if padding is needed

epochs = 16
batch_size = 50  # when nDatas is not divisible by batch_size (e.g. batch_size = 56),
                 # there will be a batch with fewer than batch_size samples

# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)  # make sure repeat comes before batch
# this differs from dataset.shuffle(1000).batch(batch_size).repeat(epochs):
# the latter produces a batch with fewer than batch_size samples in every epoch
# when nDatas is not divisible by batch_size
nBatchs = nDatasTrain*epochs//batch_size

# evaluation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs//100*2)

# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)

# make dataset iterator
iter_train = dataset_train.make_one_shot_iterator()
iter_val = dataset_val.make_one_shot_iterator()
iter_test = dataset_test.make_one_shot_iterator()

# make feedable iterator, i.e. iterator placeholder
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()
 
# cnn
x_image = tf.reshape(x, [-1,28,28,1])
w_init = tf.truncated_normal_initializer(stddev=0.1, seed=9)
b_init = tf.constant_initializer(0.1)
cnn1 = tf.layers.conv2d(x_image, 32, (5,5), padding='same', activation=tf.nn.relu,
                        kernel_initializer=w_init, bias_initializer=b_init)
mxpl1 = tf.layers.max_pooling2d(cnn1, 2, strides=2, padding='same')
cnn2 = tf.layers.conv2d(mxpl1, 64, (5,5), padding='same', activation=tf.nn.relu,
                        kernel_initializer=w_init, bias_initializer=b_init)
mxpl2 = tf.layers.max_pooling2d(cnn2, 2, strides=2, padding='same')
mxpl2_flat = tf.reshape(mxpl2, [-1,7*7*64])
fc1 = tf.layers.dense(mxpl2_flat, 1024, activation=tf.nn.relu,
                      kernel_initializer=w_init, bias_initializer=b_init)
keep_prob = tf.placeholder('float')
fc1_drop = tf.nn.dropout(fc1, keep_prob)
logits = tf.layers.dense(fc1_drop, 10, kernel_initializer=w_init, bias_initializer=b_init)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
optmz = tf.train.AdamOptimizer(1e-4)
train_op = optmz.minimize(loss)

def get_eval_op(logits, labels):
    corr_prd = tf.equal(tf.argmax(logits,1), tf.argmax(labels,1))
    return tf.reduce_mean(tf.cast(corr_prd, 'float'))
eval_op = get_eval_op(logits, y_)

init = tf.global_variables_initializer()

# summary
logdir = './logs/m4d2a'
def summary_op(datapart='train'):
    # merge only the two scalars created here; merge_all() would also pick up
    # the summaries created by earlier calls
    return tf.summary.merge([
        tf.summary.scalar(datapart + '-loss', loss),
        tf.summary.scalar(datapart + '-eval', eval_op)])
summary_op_train = summary_op()
summary_op_val = summary_op('val')

# whether to restore or not
ckpts_dir = 'ckpts/'
ckpt_nm = 'cnn-ckpt'
saver = tf.train.Saver(max_to_keep=50)  # saves all variables by default; pass a dict {'x': x, ...} to save specific ones
restore_step = ''  # '' trains from scratch, 'latest' resumes from the latest checkpoint, a step number loads that checkpoint
start_step = 0
train_steps = nBatchs
best_loss = 1e6
best_step = 0
 
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# config = tf.ConfigProto()
# config.gpu_options.per_process_gpu_memory_fraction = 0.9
# config.gpu_options.allow_growth=True # allocate when needed
# with tf.Session(config=config) as sess:
with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    if restore_step:
        ckpt = tf.train.get_checkpoint_state(ckpts_dir)
        if ckpt and ckpt.model_checkpoint_path:  # ckpt.model_checkpoint_path is the latest checkpoint
            if restore_step == 'latest':
                ckpt_f = tf.train.latest_checkpoint(ckpts_dir)
                start_step = int(ckpt_f.split('-')[-1]) + 1
            else:
                ckpt_f = ckpts_dir+ckpt_nm+'-'+restore_step
            print('loading wgt file: '+ ckpt_f)
            saver.restore(sess, ckpt_f)
    summary_wrt = tf.summary.FileWriter(logdir, sess.graph)
    if restore_step in ['', 'latest']:
        for i in range(start_step, train_steps):
            _, cur_loss, cur_train_eval, summary = sess.run(
                [train_op, loss, eval_op, summary_op_train],
                feed_dict={handle: handle_train, keep_prob: 0.5})
            # log to stdout and eval validation set
            if i % 100 == 0 or i == train_steps-1:
                saver.save(sess, ckpts_dir+ckpt_nm, global_step=i)  # save variables
                summary_wrt.add_summary(summary, global_step=i)
                cur_val_loss, cur_val_eval, summary = sess.run(
                    [loss, eval_op, summary_op_val],
                    feed_dict={handle: handle_val, keep_prob: 1.0})
                if cur_val_loss < best_loss:
                    best_loss = cur_val_loss
                    best_step = i
                summary_wrt.add_summary(summary, global_step=i)
                print('step %5d: loss %.5f, acc %.5f --- loss val %0.5f, acc val %.5f'%(
                    i, cur_loss, cur_train_eval, cur_val_loss, cur_val_eval))
            # sess.run(init_train)
    with open(ckpts_dir+'best.step','w') as f:
        f.write('best step is %d\n'%best_step)
    print('best step is %d'%best_step)
    # eval test set
    test_loss, test_eval = sess.run([loss, eval_op], feed_dict={handle: handle_test, keep_prob: 1.0})
    print('eval test: loss %.5f, acc %.5f'%(test_loss, test_eval))
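
Because the validation data comes from a one-shot iterator, it has to supply at least as many batches as the training loop requests. A quick plain-Python check with the numbers used in this article, assuming training starts from step 0 (count_validation_runs is a hypothetical helper that mirrors the loop condition):

```python
def count_validation_runs(start_step, train_steps, every=100):
    """Count the steps at which the training loop evaluates the validation set."""
    return sum(1 for i in range(start_step, train_steps)
               if i % every == 0 or i == train_steps - 1)

n_train, epochs, batch_size = 46750, 16, 50
n_batches = n_train * epochs // batch_size      # nBatchs in the article
val_batches_available = n_batches // 100 * 2    # repeat count of the validation dataset
val_runs = count_validation_runs(0, n_batches)
print(n_batches, val_batches_available, val_runs)  # prints: 14960 298 151
```

The validation dataset provides 298 batches while the loop asks for 151, so the iterator is not exhausted during training.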

Experimental results:

[Figure: result image from the original post]
