tensorflow TFRecords文件的生成和读取的方法


Posted in Python onFebruary 06, 2018

TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据。

TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。

message Example { 
  Features features = 1; 
}; 
message Features { 
  map<string, Feature> feature = 1; 
}; 
message Feature { 
  oneof kind { 
  BytesList bytes_list = 1; 
  FloatList float_list = 2; 
  Int64List int64_list = 3; 
} 
};

下面将介绍如何生成和读取tfrecords文件:

首先介绍tfrecords文件的生成,直接上代码:

from random import shuffle 
import numpy as np 
import glob 
import tensorflow as tf 
import cv2 
import sys 
import os 
 
# 因为我装的是CPU版本的,运行起来会有'warning',解决方法入下,眼不见为净~ 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
 
shuffle_data = True 
image_path = '/path/to/image/*.jpg' 
 
# 取得该路径下所有图片的路径,type(addrs)= list 
addrs = glob.glob(image_path) 
# 标签数据的获得具体情况具体分析,type(labels)= list 
labels = ... 
 
# 这里是打乱数据的顺序 
if shuffle_data: 
  c = list(zip(addrs, labels)) 
  shuffle(c) 
  addrs, labels = zip(*c) 
 
# 按需分割数据集 
train_addrs = addrs[0:int(0.7*len(addrs))] 
train_labels = labels[0:int(0.7*len(labels))] 
 
val_addrs = addrs[int(0.7*len(addrs)):int(0.9*len(addrs))] 
val_labels = labels[int(0.7*len(labels)):int(0.9*len(labels))] 
 
test_addrs = addrs[int(0.9*len(addrs)):] 
test_labels = labels[int(0.9*len(labels)):] 
 
# 上面不是获得了image的地址么,下面这个函数就是根据地址获取图片 
def load_image(addr): # A function to Load image 
  img = cv2.imread(addr) 
  img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) 
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 
  # 这里/255是为了将像素值归一化到[0,1] 
  img = img / 255. 
  img = img.astype(np.float32) 
  return img 
 
# 将数据转化成对应的属性 
def _int64_feature(value):  
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
 
 
def _bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
 
 
def _float_feature(value): 
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 
 
# 下面这段就开始把数据写入TFRecods文件 
 
train_filename = '/path/to/train.tfrecords' # 输出文件地址 
 
# 创建一个writer来写 TFRecords 文件 
writer = tf.python_io.TFRecordWriter(train_filename) 
 
for i in range(len(train_addrs)): 
  # 这是写入操作可视化处理 
  if not i % 1000: 
    print('Train data: {}/{}'.format(i, len(train_addrs))) 
    sys.stdout.flush() 
  # 加载图片 
  img = load_image(train_addrs[i]) 
 
  label = train_labels[i] 
 
  # 创建一个属性(feature) 
  feature = {'train/label': _int64_feature(label), 
        'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))} 
 
  # 创建一个 example protocol buffer 
  example = tf.train.Example(features=tf.train.Features(feature=feature)) 
 
  # 将上面的example protocol buffer写入文件 
  writer.write(example.SerializeToString()) 
 
writer.close() 
sys.stdout.flush()

上面只介绍了train.tfrecords文件的生成,其余的validation,test举一反三吧。。

接下来介绍tfrecords文件的读取:

import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
data_path = 'train.tfrecords' # tfrecords 文件的地址 
 
with tf.Session() as sess: 
  # 先定义feature,这里要和之前创建的时候保持一致 
  feature = { 
    'train/image': tf.FixedLenFeature([], tf.string), 
    'train/label': tf.FixedLenFeature([], tf.int64) 
  } 
  # 创建一个队列来维护输入文件列表 
  filename_queue = tf.train.string_input_producer([data_path], num_epochs=1) 
 
  # 定义一个 reader ,读取下一个 record 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
 
  # 解析读入的一个record 
  features = tf.parse_single_example(serialized_example, features=feature) 
 
  # 将字符串解析成图像对应的像素组 
  image = tf.decode_raw(features['train/image'], tf.float32) 
 
  # 将标签转化成int32 
  label = tf.cast(features['train/label'], tf.int32) 
 
  # 这里将图片还原成原来的维度 
  image = tf.reshape(image, [224, 224, 3]) 
 
  # 你还可以进行其他一些预处理.... 
 
  # 这里是创建顺序随机 batches(函数不懂的自行百度) 
  images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, min_after_dequeue=10) 
 
  # 初始化 
  init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 
  sess.run(init_op) 
 
  # 启动多线程处理输入数据 
  coord = tf.train.Coordinator() 
  threads = tf.train.start_queue_runners(coord=coord) 
 
  .... 
 
  #关闭线程 
  coord.request_stop() 
  coord.join(threads) 
  sess.close()

好了,就介绍到这里。。,有什么问题可以留言。。大家一起学习。。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python用来获得图片exif信息的库实例分析
Mar 16 Python
Python实现图像几何变换
Jul 06 Python
在django中使用自定义标签实现分页功能
Jul 04 Python
Python学习笔记之if语句的使用示例
Oct 23 Python
关于Django显示时间你应该知道的一些问题
Dec 25 Python
Python 类的特殊成员解析
Jun 20 Python
解决tensorflow模型参数保存和加载的问题
Jul 26 Python
python3 读取Excel表格中的数据
Oct 16 Python
Python assert关键字原理及实例解析
Dec 13 Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 Python
Python爬虫获取op.gg英雄联盟英雄对位胜率的源码
Jan 29 Python
C站最全Python标准库总结,你想要的都在这里
Jul 03 Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
You might like
php 安全过滤函数代码
2011/05/07 PHP
浅谈php中fopen不能创建中文文件名文件的问题
2017/02/06 PHP
Django中的cookie与session操作实例代码
2017/08/17 PHP
任意位置显示html菜单
2007/02/01 Javascript
javascript实现上传图片并预览的效果实现代码
2011/04/11 Javascript
Extjs 3.3切换tab隐藏相应工具栏出现空白解决
2013/04/02 Javascript
Extjs3.0 checkboxGroup 动态添加item实现思路
2013/08/14 Javascript
在HTML代码中使用JavaScript代码的例子
2014/10/16 Javascript
node.js中Socket.IO的进阶使用技巧
2014/11/04 Javascript
JavaScript类型系统之Object详解
2016/01/07 Javascript
JavaScript中关联原型链属性特性
2016/02/13 Javascript
Jquery获取当前城市的天气信息
2016/08/05 Javascript
AngularJS全局scope与Isolate scope通信用法示例
2016/11/22 Javascript
详解Vue.js入门环境搭建
2017/03/17 Javascript
Vue.js实现一个漂亮、灵活、可复用的提示组件示例
2017/03/17 Javascript
js实现左右两侧浮动广告
2018/07/09 Javascript
vue移动端html5页面根据屏幕适配的四种解决方法
2018/10/19 Javascript
vue-cli3项目展示本地Markdown文件的方法
2019/06/07 Javascript
bootstrap-treeview实现多级树形菜单 后台JSON格式如何组织?
2019/07/26 Javascript
JavaScript实现tab栏切换效果
2020/03/16 Javascript
[10:18]2018DOTA2国际邀请赛寻真——Fnatic能否笑到最后?
2018/08/14 DOTA
Ubuntu下安装PyV8
2016/03/13 Python
Python 普通最小二乘法(OLS)进行多项式拟合的方法
2018/12/29 Python
对pandas处理json数据的方法详解
2019/02/08 Python
PyQt5连接MySQL及QMYSQL driver not loaded错误解决
2020/04/29 Python
HTML5 语义化结构化规范化
2008/10/17 HTML / CSS
印尼穆斯林时尚购物网站:Hijabenka
2016/12/10 全球购物
Expedia印度尼西亚站:预订酒店、廉价航班和度假套餐
2018/01/31 全球购物
学前教育学生自荐信范文
2013/12/31 职场文书
物流仓管员工作职责
2014/01/06 职场文书
群众路线党课主持词
2014/04/01 职场文书
公安民警正风肃纪剖析材料
2014/10/10 职场文书
车间统计员岗位职责
2015/04/14 职场文书
python基础详解之if循环语句
2021/04/24 Python
一小时迅速入门Mybatis之bind与多数据源支持 Java API
2021/09/15 Javascript
redis sentinel监控高可用集群实现的配置步骤
2022/04/01 Redis