tensorflow TFRecords文件的生成和读取的方法


Posted in Python onFebruary 06, 2018

TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据。

TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。

message Example { 
  Features features = 1; 
}; 
message Features { 
  map<string, Feature> feature = 1; 
}; 
message Feature { 
  oneof kind { 
  BytesList bytes_list = 1; 
  FloatList float_list = 2; 
  Int64List int64_list = 3; 
} 
};

下面将介绍如何生成和读取tfrecords文件:

首先介绍tfrecords文件的生成,直接上代码:

from random import shuffle 
import numpy as np 
import glob 
import tensorflow as tf 
import cv2 
import sys 
import os 
 
# 因为我装的是CPU版本的,运行起来会有'warning',解决方法入下,眼不见为净~ 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
 
shuffle_data = True 
image_path = '/path/to/image/*.jpg' 
 
# 取得该路径下所有图片的路径,type(addrs)= list 
addrs = glob.glob(image_path) 
# 标签数据的获得具体情况具体分析,type(labels)= list 
labels = ... 
 
# 这里是打乱数据的顺序 
if shuffle_data: 
  c = list(zip(addrs, labels)) 
  shuffle(c) 
  addrs, labels = zip(*c) 
 
# 按需分割数据集 
train_addrs = addrs[0:int(0.7*len(addrs))] 
train_labels = labels[0:int(0.7*len(labels))] 
 
val_addrs = addrs[int(0.7*len(addrs)):int(0.9*len(addrs))] 
val_labels = labels[int(0.7*len(labels)):int(0.9*len(labels))] 
 
test_addrs = addrs[int(0.9*len(addrs)):] 
test_labels = labels[int(0.9*len(labels)):] 
 
# 上面不是获得了image的地址么,下面这个函数就是根据地址获取图片 
def load_image(addr): # A function to Load image 
  img = cv2.imread(addr) 
  img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) 
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 
  # 这里/255是为了将像素值归一化到[0,1] 
  img = img / 255. 
  img = img.astype(np.float32) 
  return img 
 
# 将数据转化成对应的属性 
def _int64_feature(value):  
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
 
 
def _bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
 
 
def _float_feature(value): 
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 
 
# 下面这段就开始把数据写入TFRecods文件 
 
train_filename = '/path/to/train.tfrecords' # 输出文件地址 
 
# 创建一个writer来写 TFRecords 文件 
writer = tf.python_io.TFRecordWriter(train_filename) 
 
for i in range(len(train_addrs)): 
  # 这是写入操作可视化处理 
  if not i % 1000: 
    print('Train data: {}/{}'.format(i, len(train_addrs))) 
    sys.stdout.flush() 
  # 加载图片 
  img = load_image(train_addrs[i]) 
 
  label = train_labels[i] 
 
  # 创建一个属性(feature) 
  feature = {'train/label': _int64_feature(label), 
        'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))} 
 
  # 创建一个 example protocol buffer 
  example = tf.train.Example(features=tf.train.Features(feature=feature)) 
 
  # 将上面的example protocol buffer写入文件 
  writer.write(example.SerializeToString()) 
 
writer.close() 
sys.stdout.flush()

上面只介绍了train.tfrecords文件的生成,其余的validation,test举一反三吧。。

接下来介绍tfrecords文件的读取:

import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
data_path = 'train.tfrecords' # tfrecords 文件的地址 
 
with tf.Session() as sess: 
  # 先定义feature,这里要和之前创建的时候保持一致 
  feature = { 
    'train/image': tf.FixedLenFeature([], tf.string), 
    'train/label': tf.FixedLenFeature([], tf.int64) 
  } 
  # 创建一个队列来维护输入文件列表 
  filename_queue = tf.train.string_input_producer([data_path], num_epochs=1) 
 
  # 定义一个 reader ,读取下一个 record 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
 
  # 解析读入的一个record 
  features = tf.parse_single_example(serialized_example, features=feature) 
 
  # 将字符串解析成图像对应的像素组 
  image = tf.decode_raw(features['train/image'], tf.float32) 
 
  # 将标签转化成int32 
  label = tf.cast(features['train/label'], tf.int32) 
 
  # 这里将图片还原成原来的维度 
  image = tf.reshape(image, [224, 224, 3]) 
 
  # 你还可以进行其他一些预处理.... 
 
  # 这里是创建顺序随机 batches(函数不懂的自行百度) 
  images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, min_after_dequeue=10) 
 
  # 初始化 
  init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 
  sess.run(init_op) 
 
  # 启动多线程处理输入数据 
  coord = tf.train.Coordinator() 
  threads = tf.train.start_queue_runners(coord=coord) 
 
  .... 
 
  #关闭线程 
  coord.request_stop() 
  coord.join(threads) 
  sess.close()

好了,就介绍到这里。。,有什么问题可以留言。。大家一起学习。。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 自动提交和抓取网页
Jul 13 Python
PyMongo安装使用笔记
Apr 27 Python
Python的SQLalchemy模块连接与操作MySQL的基础示例
Jul 11 Python
Python搭建HTTP服务器和FTP服务器
Mar 09 Python
python解决pandas处理缺失值为空字符串的问题
Apr 08 Python
Python3中lambda表达式与函数式编程讲解
Jan 14 Python
手写一个python迭代器过程详解
Aug 27 Python
python实现回旋矩阵方式(旋转矩阵)
Dec 04 Python
python3 使用openpyxl将mysql数据写入xlsx的操作
May 15 Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 Python
手把手教你如何用Pycharm2020.1.1配置远程连接的详细步骤
Aug 07 Python
pycharm激活码免费分享适用最新pycharm2020.2.3永久激活
Nov 25 Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
You might like
php封装的数据库函数与用法示例【参考thinkPHP】
2016/11/08 PHP
PHP基于session.upload_progress 实现文件上传进度显示功能详解
2019/08/09 PHP
PHP设计模式之迭代器模式Iterator实例分析【对象行为型】
2020/04/26 PHP
jquery隐藏标签和显示标签的实例
2013/11/11 Javascript
jquery引用方法时传递参数原理分析
2014/10/13 Javascript
js QQ客服悬浮效果实现代码
2014/12/12 Javascript
javascript实现获取服务器时间
2015/05/19 Javascript
jQuery判断指定id的对象是否存在的方法
2015/05/22 Javascript
浅谈js 闭包引起的内存泄露问题
2015/06/22 Javascript
JavaScript中的操作符类型转换示例总结
2016/05/30 Javascript
JavaScript兼容性总结之获取非行间样式案例
2016/08/07 Javascript
jQuery 局部div刷新和全局刷新方法总结
2016/10/05 Javascript
Bootstrap中glyphicons-halflings-regular.woff字体报404错notfound的解决方法
2017/01/19 Javascript
angular中使用Socket.io实例代码
2017/06/03 Javascript
vue移动UI框架滑动加载数据的方法
2018/03/12 Javascript
解决layui前端框架 form表单,table表等内置控件不显示的问题
2018/08/19 Javascript
解决jQuery使用append添加的元素事件无效的问题
2018/08/30 jQuery
简化版的vue-router实现思路详解
2018/10/19 Javascript
用Vue.js在浏览器中实现裁剪图像功能
2019/06/18 Javascript
[44:26]DOTA2上海特级锦标赛主赛事日 - 2 胜者组第一轮#4EG VS Fnatic第二局
2016/03/03 DOTA
Python文件和流(实例讲解)
2017/09/12 Python
Django入门使用示例
2017/12/12 Python
python实现简单遗传算法
2018/03/19 Python
值得收藏,Python 开发中的高级技巧
2018/11/23 Python
matplotlib命令与格式之tick坐标轴日期格式(设置日期主副刻度)
2019/08/06 Python
Jupyter Notebook 实现正常显示中文和负号
2020/04/24 Python
python3.7 openpyxl 在excel单元格中写入数据实例
2020/09/01 Python
详解修改Anaconda中的Jupyter Notebook默认工作路径的三种方式
2021/01/24 Python
泰国汽车、火车和轮渡票预订网站:Bus Online Ticket
2017/09/09 全球购物
广告设计专业自荐信范文
2013/11/14 职场文书
《李时珍夜宿古寺》教学反思
2014/04/09 职场文书
县委班子四风对照检查材料思想汇报
2014/09/29 职场文书
2014年心理健康教育工作总结
2014/12/06 职场文书
五一劳动节慰问信
2015/02/14 职场文书
教师工作能力自我评价
2015/03/04 职场文书
基于Apache Hudi在Google云构建数据湖平台的思路详解
2022/04/07 Servers