tensorflow TFRecords文件的生成和读取的方法


Posted in Python onFebruary 06, 2018

TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据。

TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。

message Example { 
  Features features = 1; 
}; 
message Features { 
  map<string, Feature> feature = 1; 
}; 
message Feature { 
  oneof kind { 
  BytesList bytes_list = 1; 
  FloatList float_list = 2; 
  Int64List int64_list = 3; 
} 
};

下面将介绍如何生成和读取tfrecords文件:

首先介绍tfrecords文件的生成,直接上代码:

from random import shuffle 
import numpy as np 
import glob 
import tensorflow as tf 
import cv2 
import sys 
import os 
 
# 因为我装的是CPU版本的,运行起来会有'warning',解决方法入下,眼不见为净~ 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
 
shuffle_data = True 
image_path = '/path/to/image/*.jpg' 
 
# 取得该路径下所有图片的路径,type(addrs)= list 
addrs = glob.glob(image_path) 
# 标签数据的获得具体情况具体分析,type(labels)= list 
labels = ... 
 
# 这里是打乱数据的顺序 
if shuffle_data: 
  c = list(zip(addrs, labels)) 
  shuffle(c) 
  addrs, labels = zip(*c) 
 
# 按需分割数据集 
train_addrs = addrs[0:int(0.7*len(addrs))] 
train_labels = labels[0:int(0.7*len(labels))] 
 
val_addrs = addrs[int(0.7*len(addrs)):int(0.9*len(addrs))] 
val_labels = labels[int(0.7*len(labels)):int(0.9*len(labels))] 
 
test_addrs = addrs[int(0.9*len(addrs)):] 
test_labels = labels[int(0.9*len(labels)):] 
 
# 上面不是获得了image的地址么,下面这个函数就是根据地址获取图片 
def load_image(addr): # A function to Load image 
  img = cv2.imread(addr) 
  img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) 
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 
  # 这里/255是为了将像素值归一化到[0,1] 
  img = img / 255. 
  img = img.astype(np.float32) 
  return img 
 
# 将数据转化成对应的属性 
def _int64_feature(value):  
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
 
 
def _bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
 
 
def _float_feature(value): 
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 
 
# 下面这段就开始把数据写入TFRecods文件 
 
train_filename = '/path/to/train.tfrecords' # 输出文件地址 
 
# 创建一个writer来写 TFRecords 文件 
writer = tf.python_io.TFRecordWriter(train_filename) 
 
for i in range(len(train_addrs)): 
  # 这是写入操作可视化处理 
  if not i % 1000: 
    print('Train data: {}/{}'.format(i, len(train_addrs))) 
    sys.stdout.flush() 
  # 加载图片 
  img = load_image(train_addrs[i]) 
 
  label = train_labels[i] 
 
  # 创建一个属性(feature) 
  feature = {'train/label': _int64_feature(label), 
        'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))} 
 
  # 创建一个 example protocol buffer 
  example = tf.train.Example(features=tf.train.Features(feature=feature)) 
 
  # 将上面的example protocol buffer写入文件 
  writer.write(example.SerializeToString()) 
 
writer.close() 
sys.stdout.flush()

上面只介绍了train.tfrecords文件的生成,其余的validation,test举一反三吧。。

接下来介绍tfrecords文件的读取:

import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
data_path = 'train.tfrecords' # tfrecords 文件的地址 
 
with tf.Session() as sess: 
  # 先定义feature,这里要和之前创建的时候保持一致 
  feature = { 
    'train/image': tf.FixedLenFeature([], tf.string), 
    'train/label': tf.FixedLenFeature([], tf.int64) 
  } 
  # 创建一个队列来维护输入文件列表 
  filename_queue = tf.train.string_input_producer([data_path], num_epochs=1) 
 
  # 定义一个 reader ,读取下一个 record 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
 
  # 解析读入的一个record 
  features = tf.parse_single_example(serialized_example, features=feature) 
 
  # 将字符串解析成图像对应的像素组 
  image = tf.decode_raw(features['train/image'], tf.float32) 
 
  # 将标签转化成int32 
  label = tf.cast(features['train/label'], tf.int32) 
 
  # 这里将图片还原成原来的维度 
  image = tf.reshape(image, [224, 224, 3]) 
 
  # 你还可以进行其他一些预处理.... 
 
  # 这里是创建顺序随机 batches(函数不懂的自行百度) 
  images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, min_after_dequeue=10) 
 
  # 初始化 
  init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 
  sess.run(init_op) 
 
  # 启动多线程处理输入数据 
  coord = tf.train.Coordinator() 
  threads = tf.train.start_queue_runners(coord=coord) 
 
  .... 
 
  #关闭线程 
  coord.request_stop() 
  coord.join(threads) 
  sess.close()

好了,就介绍到这里。。,有什么问题可以留言。。大家一起学习。。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用urllib2模块获取gravatar头像实例
Dec 18 Python
使用70行Python代码实现一个递归下降解析器的教程
Apr 17 Python
简介二分查找算法与相关的Python实现示例
Aug 26 Python
python实现简单中文词频统计示例
Nov 08 Python
python删除本地夹里重复文件的方法
Nov 19 Python
python3解析库pyquery的深入讲解
Jun 26 Python
python的继承知识点总结
Dec 10 Python
pyqt 实现为长内容添加滑轮 scrollArea
Jun 19 Python
TensorFlow实现打印每一层的输出
Jan 21 Python
jupyter notebook 调用环境中的Keras或者pytorch教程
Apr 14 Python
Pygame的程序开始示例代码
May 07 Python
Python中bisect的用法及示例详解
Jul 20 Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
You might like
php批量删除数据
2007/01/18 PHP
简单谈谈favicon
2015/06/10 PHP
详解PHP中mb_strpos的使用
2018/02/04 PHP
FCK调用方法..
2006/12/21 Javascript
HTML中Select不用Disabled实现ReadOnly的效果
2008/04/07 Javascript
说明你的Javascript技术很烂的五个原因
2011/04/26 Javascript
关于二级域名下使用一级域名下的COOKIE的问题
2011/11/07 Javascript
JavaScript插入动态样式实现代码
2012/02/22 Javascript
javascript函数命名的三种方式及区别介绍
2016/03/22 Javascript
微信小程序 教程之引用
2016/10/18 Javascript
vuejs2.0子组件改变父组件的数据实例
2017/05/10 Javascript
微信小程序实现获取自己所处位置的经纬度坐标功能示例
2017/11/30 Javascript
详解vue2.0+vue-video-player实现hls播放全过程
2018/03/02 Javascript
[01:32]2014DOTA2西雅图邀请赛 CIS我们有信心进入正赛
2014/07/08 DOTA
[01:01:24]LGD vs Fnatic 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
python简单的函数定义和用法实例
2015/05/07 Python
python定时器(Timer)用法简单实例
2015/06/04 Python
用Python设计一个经典小游戏
2017/05/15 Python
Sanic框架流式传输操作示例
2018/07/18 Python
python实现发送form-data数据的方法详解
2019/09/27 Python
Python实现隐马尔可夫模型的前向后向算法的示例代码
2019/12/31 Python
canvas之万花筒效果的简单实现(推荐)
2016/08/16 HTML / CSS
Html5内唤醒百度、高德APP的实现示例
2019/05/20 HTML / CSS
美国在线纱线商店:Darn Good Yarn
2019/03/20 全球购物
英国PC组件和在线电脑商店:SCAN
2019/04/18 全球购物
Engel & Bengel官网:婴儿推车、儿童房家具和婴儿设备
2019/12/28 全球购物
GOLFINO英国官网:高尔夫服装
2020/04/11 全球购物
static关键字的用法
2013/10/07 面试题
北京麒麟网信息技术有限公司网络游戏测试面试题
2013/09/28 面试题
大一学生假期实习的自我评价
2013/10/12 职场文书
授权委托书格式模板
2014/04/03 职场文书
2014年人事专员工作总结
2014/11/19 职场文书
优秀护士事迹材料
2014/12/25 职场文书
工作保证书怎么写
2015/02/28 职场文书
单位工作证明范本
2015/06/15 职场文书
Java常用函数式接口总结
2021/06/29 Java/Android