tensorflow TFRecords文件的生成和读取的方法


Posted in Python onFebruary 06, 2018

TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据。

TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。

message Example { 
  Features features = 1; 
}; 
message Features { 
  map<string, Feature> feature = 1; 
}; 
message Feature { 
  oneof kind { 
  BytesList bytes_list = 1; 
  FloatList float_list = 2; 
  Int64List int64_list = 3; 
} 
};

下面将介绍如何生成和读取tfrecords文件:

首先介绍tfrecords文件的生成,直接上代码:

from random import shuffle 
import numpy as np 
import glob 
import tensorflow as tf 
import cv2 
import sys 
import os 
 
# 因为我装的是CPU版本的,运行起来会有'warning',解决方法入下,眼不见为净~ 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
 
shuffle_data = True 
image_path = '/path/to/image/*.jpg' 
 
# 取得该路径下所有图片的路径,type(addrs)= list 
addrs = glob.glob(image_path) 
# 标签数据的获得具体情况具体分析,type(labels)= list 
labels = ... 
 
# 这里是打乱数据的顺序 
if shuffle_data: 
  c = list(zip(addrs, labels)) 
  shuffle(c) 
  addrs, labels = zip(*c) 
 
# 按需分割数据集 
train_addrs = addrs[0:int(0.7*len(addrs))] 
train_labels = labels[0:int(0.7*len(labels))] 
 
val_addrs = addrs[int(0.7*len(addrs)):int(0.9*len(addrs))] 
val_labels = labels[int(0.7*len(labels)):int(0.9*len(labels))] 
 
test_addrs = addrs[int(0.9*len(addrs)):] 
test_labels = labels[int(0.9*len(labels)):] 
 
# 上面不是获得了image的地址么,下面这个函数就是根据地址获取图片 
def load_image(addr): # A function to Load image 
  img = cv2.imread(addr) 
  img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) 
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 
  # 这里/255是为了将像素值归一化到[0,1] 
  img = img / 255. 
  img = img.astype(np.float32) 
  return img 
 
# 将数据转化成对应的属性 
def _int64_feature(value):  
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
 
 
def _bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
 
 
def _float_feature(value): 
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 
 
# 下面这段就开始把数据写入TFRecods文件 
 
train_filename = '/path/to/train.tfrecords' # 输出文件地址 
 
# 创建一个writer来写 TFRecords 文件 
writer = tf.python_io.TFRecordWriter(train_filename) 
 
for i in range(len(train_addrs)): 
  # 这是写入操作可视化处理 
  if not i % 1000: 
    print('Train data: {}/{}'.format(i, len(train_addrs))) 
    sys.stdout.flush() 
  # 加载图片 
  img = load_image(train_addrs[i]) 
 
  label = train_labels[i] 
 
  # 创建一个属性(feature) 
  feature = {'train/label': _int64_feature(label), 
        'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))} 
 
  # 创建一个 example protocol buffer 
  example = tf.train.Example(features=tf.train.Features(feature=feature)) 
 
  # 将上面的example protocol buffer写入文件 
  writer.write(example.SerializeToString()) 
 
writer.close() 
sys.stdout.flush()

上面只介绍了train.tfrecords文件的生成,其余的validation,test举一反三吧。。

接下来介绍tfrecords文件的读取:

import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
data_path = 'train.tfrecords' # tfrecords 文件的地址 
 
with tf.Session() as sess: 
  # 先定义feature,这里要和之前创建的时候保持一致 
  feature = { 
    'train/image': tf.FixedLenFeature([], tf.string), 
    'train/label': tf.FixedLenFeature([], tf.int64) 
  } 
  # 创建一个队列来维护输入文件列表 
  filename_queue = tf.train.string_input_producer([data_path], num_epochs=1) 
 
  # 定义一个 reader ,读取下一个 record 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
 
  # 解析读入的一个record 
  features = tf.parse_single_example(serialized_example, features=feature) 
 
  # 将字符串解析成图像对应的像素组 
  image = tf.decode_raw(features['train/image'], tf.float32) 
 
  # 将标签转化成int32 
  label = tf.cast(features['train/label'], tf.int32) 
 
  # 这里将图片还原成原来的维度 
  image = tf.reshape(image, [224, 224, 3]) 
 
  # 你还可以进行其他一些预处理.... 
 
  # 这里是创建顺序随机 batches(函数不懂的自行百度) 
  images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, min_after_dequeue=10) 
 
  # 初始化 
  init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 
  sess.run(init_op) 
 
  # 启动多线程处理输入数据 
  coord = tf.train.Coordinator() 
  threads = tf.train.start_queue_runners(coord=coord) 
 
  .... 
 
  #关闭线程 
  coord.request_stop() 
  coord.join(threads) 
  sess.close()

好了,就介绍到这里。。,有什么问题可以留言。。大家一起学习。。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python打开网页和暂停实例
Sep 30 Python
python编码总结(编码类型、格式、转码)
Jul 01 Python
使用python遍历指定城市的一周气温
Mar 31 Python
Python学习笔记之自定义函数用法详解
Jun 08 Python
使用Python opencv实现视频与图片的相互转换
Jul 08 Python
python检测服务器端口代码实例
Aug 31 Python
解决Pycharm的项目目录突然消失的问题
Jan 20 Python
在python里使用await关键字来等另外一个协程的实例
May 04 Python
利用Python函数实现一个万历表完整示例
Jan 23 Python
pytest fixtures装饰器的使用和如何控制用例的执行顺序
Jan 28 Python
使用pandas生成/读取csv文件的方法实例
Jul 09 Python
一篇文章弄懂Python中的内建函数
Aug 07 Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
You might like
比较简单的百度网盘文件直链PHP代码
2013/03/24 PHP
PHP和JavaScrip分别获取关联数组的键值示例代码
2013/09/16 PHP
Centos 6.5系统下编译安装PHP 7.0.13的方法
2016/12/19 PHP
PHP常用操作类之通信数据封装类的实现
2017/07/16 PHP
PHP自动载入类文件函数__autoload的使用方法
2019/03/25 PHP
文字幻灯片
2006/06/26 Javascript
JavaScript 判断指定字符串是否为有效数字
2010/05/11 Javascript
Javascript的各种节点操作实例演示代码
2012/06/27 Javascript
js实现无限级树形导航列表效果代码
2015/09/23 Javascript
在JavaScript中如何解决用execCommand(
2015/10/19 Javascript
jQuery随手笔记之常用的jQuery操作DOM事件
2015/11/29 Javascript
ReactNative页面跳转实例代码
2016/09/27 Javascript
数组Array的一些方法(总结)
2017/02/17 Javascript
详解使用fetch发送post请求时的参数处理
2017/04/05 Javascript
jQuery实现的简单动态添加、删除表格功能示例
2017/09/21 jQuery
seajs中最常用的7个功能、配置示例
2017/10/10 Javascript
JS实现的base64加密解密操作示例
2018/04/18 Javascript
vue使用jsonp抓取qq音乐数据的方法
2018/06/21 Javascript
如何优雅的在一台vps(云主机)上面部署vue+mongodb+express项目
2019/01/20 Javascript
解决vue中使用proxy配置不同端口和ip接口问题
2019/08/14 Javascript
[02:08:58]2014 DOTA2国际邀请赛中国区预选赛 Ne VS CIS
2014/05/22 DOTA
教你安装python Django(图文)
2013/11/04 Python
详解python的几种标准输出重定向方式
2016/08/15 Python
Python爬虫天气预报实例详解(小白入门)
2018/01/24 Python
python3.6 如何将list存入txt后再读出list的方法
2019/07/02 Python
Java中有几种类型的流?JDK为每种类型的流提供了一些抽象类以供继承,请说出他们分别是哪些类
2012/02/06 面试题
服装厂厂长职责
2013/12/16 职场文书
平民服装店创业计划书
2014/01/17 职场文书
监察建议书格式
2014/05/19 职场文书
北京奥运会口号
2014/06/21 职场文书
2014大学生批评与自我批评思想汇报
2014/09/21 职场文书
工作简报格式范文
2015/07/21 职场文书
党务工作者主要事迹材料
2015/11/03 职场文书
2016年世界人口日宣传活动总结
2016/04/05 职场文书
导游词之重庆钓鱼城
2019/09/19 职场文书
导游词之镇江-金山寺
2019/10/14 职场文书