tensorflow TFRecords文件的生成和读取的方法


Posted in Python onFebruary 06, 2018

TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据。

TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。

message Example { 
  Features features = 1; 
}; 
message Features { 
  map<string, Feature> feature = 1; 
}; 
message Feature { 
  oneof kind { 
  BytesList bytes_list = 1; 
  FloatList float_list = 2; 
  Int64List int64_list = 3; 
} 
};

下面将介绍如何生成和读取tfrecords文件:

首先介绍tfrecords文件的生成,直接上代码:

from random import shuffle 
import numpy as np 
import glob 
import tensorflow as tf 
import cv2 
import sys 
import os 
 
# 因为我装的是CPU版本的,运行起来会有'warning',解决方法入下,眼不见为净~ 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
 
shuffle_data = True 
image_path = '/path/to/image/*.jpg' 
 
# 取得该路径下所有图片的路径,type(addrs)= list 
addrs = glob.glob(image_path) 
# 标签数据的获得具体情况具体分析,type(labels)= list 
labels = ... 
 
# 这里是打乱数据的顺序 
if shuffle_data: 
  c = list(zip(addrs, labels)) 
  shuffle(c) 
  addrs, labels = zip(*c) 
 
# 按需分割数据集 
train_addrs = addrs[0:int(0.7*len(addrs))] 
train_labels = labels[0:int(0.7*len(labels))] 
 
val_addrs = addrs[int(0.7*len(addrs)):int(0.9*len(addrs))] 
val_labels = labels[int(0.7*len(labels)):int(0.9*len(labels))] 
 
test_addrs = addrs[int(0.9*len(addrs)):] 
test_labels = labels[int(0.9*len(labels)):] 
 
# 上面不是获得了image的地址么,下面这个函数就是根据地址获取图片 
def load_image(addr): # A function to Load image 
  img = cv2.imread(addr) 
  img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) 
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 
  # 这里/255是为了将像素值归一化到[0,1] 
  img = img / 255. 
  img = img.astype(np.float32) 
  return img 
 
# 将数据转化成对应的属性 
def _int64_feature(value):  
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
 
 
def _bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
 
 
def _float_feature(value): 
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 
 
# 下面这段就开始把数据写入TFRecods文件 
 
train_filename = '/path/to/train.tfrecords' # 输出文件地址 
 
# 创建一个writer来写 TFRecords 文件 
writer = tf.python_io.TFRecordWriter(train_filename) 
 
for i in range(len(train_addrs)): 
  # 这是写入操作可视化处理 
  if not i % 1000: 
    print('Train data: {}/{}'.format(i, len(train_addrs))) 
    sys.stdout.flush() 
  # 加载图片 
  img = load_image(train_addrs[i]) 
 
  label = train_labels[i] 
 
  # 创建一个属性(feature) 
  feature = {'train/label': _int64_feature(label), 
        'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))} 
 
  # 创建一个 example protocol buffer 
  example = tf.train.Example(features=tf.train.Features(feature=feature)) 
 
  # 将上面的example protocol buffer写入文件 
  writer.write(example.SerializeToString()) 
 
writer.close() 
sys.stdout.flush()

上面只介绍了train.tfrecords文件的生成,其余的validation,test举一反三吧。。

接下来介绍tfrecords文件的读取:

import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
data_path = 'train.tfrecords' # tfrecords 文件的地址 
 
with tf.Session() as sess: 
  # 先定义feature,这里要和之前创建的时候保持一致 
  feature = { 
    'train/image': tf.FixedLenFeature([], tf.string), 
    'train/label': tf.FixedLenFeature([], tf.int64) 
  } 
  # 创建一个队列来维护输入文件列表 
  filename_queue = tf.train.string_input_producer([data_path], num_epochs=1) 
 
  # 定义一个 reader ,读取下一个 record 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
 
  # 解析读入的一个record 
  features = tf.parse_single_example(serialized_example, features=feature) 
 
  # 将字符串解析成图像对应的像素组 
  image = tf.decode_raw(features['train/image'], tf.float32) 
 
  # 将标签转化成int32 
  label = tf.cast(features['train/label'], tf.int32) 
 
  # 这里将图片还原成原来的维度 
  image = tf.reshape(image, [224, 224, 3]) 
 
  # 你还可以进行其他一些预处理.... 
 
  # 这里是创建顺序随机 batches(函数不懂的自行百度) 
  images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, min_after_dequeue=10) 
 
  # 初始化 
  init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 
  sess.run(init_op) 
 
  # 启动多线程处理输入数据 
  coord = tf.train.Coordinator() 
  threads = tf.train.start_queue_runners(coord=coord) 
 
  .... 
 
  #关闭线程 
  coord.request_stop() 
  coord.join(threads) 
  sess.close()

好了,就介绍到这里。。,有什么问题可以留言。。大家一起学习。。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3写爬取B站视频弹幕功能
Dec 22 Python
详解Python核心编程中的浅拷贝与深拷贝
Jan 07 Python
windows环境下tensorflow安装过程详解
Mar 30 Python
opencv python 2D直方图的示例代码
Jul 20 Python
Python正则表达式和re库知识点总结
Feb 11 Python
Python中xml和dict格式转换的示例代码
Nov 07 Python
flask 使用 flask_apscheduler 做定时循环任务的实现
Dec 10 Python
Pytorch释放显存占用方式
Jan 13 Python
Python semaphore evevt生产者消费者模型原理解析
Mar 18 Python
Anaconda的安装及其环境变量的配置详解
Apr 22 Python
利用PyQt5+Matplotlib 绘制静态/动态图的实现代码
Jul 13 Python
Python中使用aiohttp模拟服务器出现错误问题及解决方法
Oct 31 Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
You might like
可以保证单词完整性的PHP英文字符串截取代码分享
2014/07/15 PHP
php树型类实例
2014/12/05 PHP
thinkPHP使用post方式查询时分页失效的解决方法
2015/12/09 PHP
thinkPHP实现递归循环栏目并按照树形结构无限极输出的方法
2016/05/19 PHP
WordPress过滤垃圾评论的几种主要方法小结
2016/07/11 PHP
你不知道的文件上传漏洞php代码分析
2016/09/29 PHP
PHP编辑器PhpStrom运行缓慢问题
2017/02/21 PHP
用JS实现一个页面多个css样式实现
2008/05/29 Javascript
JQuery里面的几种选择器 查找满足条件的元素$(&quot;#控件ID&quot;)
2011/08/23 Javascript
javascript面向对象编程代码
2011/12/19 Javascript
用Jquery重写windows.alert方法实现思路
2013/04/03 Javascript
javascript检测对象中是否存在某个属性判断方法小结
2013/05/19 Javascript
js获取数组的最后一个元素
2015/04/14 Javascript
jquery实现的3D旋转木马特效代码分享
2015/08/25 Javascript
JS+CSS实现仿支付宝菜单选中效果代码
2015/09/25 Javascript
JavaScript jquery及AJAX小结
2016/01/24 Javascript
js实现做通讯录的索引滑动显示效果和滑动显示锚点效果
2017/02/18 Javascript
从零学习node.js之mysql数据库的操作(五)
2017/02/24 Javascript
canvas绘制多边形
2017/02/24 Javascript
浅谈关于angularJs中使用$.ajax的注意点
2017/08/12 Javascript
解决vue+webpack打包路径的问题
2018/03/06 Javascript
JS 验证码功能的三种实现方式
2018/11/26 Javascript
Vuejs监听vuex中值的变化的方法示例
2018/12/02 Javascript
JS基于开关思想实现的数组去重功能【案例】
2019/02/18 Javascript
微信小程序如何再次获取用户授权的方法
2019/05/10 Javascript
JavaScript 实现轮播图特效的示例
2020/11/05 Javascript
深入学习Python中的装饰器使用
2016/06/20 Python
面向初学者的Python编辑器Mu
2018/10/08 Python
Win系统PyQt5安装和使用教程
2019/12/25 Python
如何使用python记录室友的抖音在线时间
2020/06/29 Python
HTML5 FormData 方法介绍以及实现文件上传示例
2017/09/12 HTML / CSS
详解移动端html5页面长按实现高亮全选文本内容的兼容解决方案
2016/12/03 HTML / CSS
西班牙手机之家:Phone House
2018/10/18 全球购物
计算机专业毕业生求职信
2014/04/30 职场文书
大学生团日活动总结
2015/05/06 职场文书
浅谈Laravel中使用Slack进行异常通知
2021/05/29 PHP