tensorflow TFRecords文件的生成和读取的方法


Posted in Python onFebruary 06, 2018

TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据。

TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。

message Example { 
  Features features = 1; 
}; 
message Features { 
  map<string, Feature> feature = 1; 
}; 
message Feature { 
  oneof kind { 
  BytesList bytes_list = 1; 
  FloatList float_list = 2; 
  Int64List int64_list = 3; 
} 
};

下面将介绍如何生成和读取tfrecords文件:

首先介绍tfrecords文件的生成,直接上代码:

from random import shuffle 
import numpy as np 
import glob 
import tensorflow as tf 
import cv2 
import sys 
import os 
 
# 因为我装的是CPU版本的,运行起来会有'warning',解决方法入下,眼不见为净~ 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
 
shuffle_data = True 
image_path = '/path/to/image/*.jpg' 
 
# 取得该路径下所有图片的路径,type(addrs)= list 
addrs = glob.glob(image_path) 
# 标签数据的获得具体情况具体分析,type(labels)= list 
labels = ... 
 
# 这里是打乱数据的顺序 
if shuffle_data: 
  c = list(zip(addrs, labels)) 
  shuffle(c) 
  addrs, labels = zip(*c) 
 
# 按需分割数据集 
train_addrs = addrs[0:int(0.7*len(addrs))] 
train_labels = labels[0:int(0.7*len(labels))] 
 
val_addrs = addrs[int(0.7*len(addrs)):int(0.9*len(addrs))] 
val_labels = labels[int(0.7*len(labels)):int(0.9*len(labels))] 
 
test_addrs = addrs[int(0.9*len(addrs)):] 
test_labels = labels[int(0.9*len(labels)):] 
 
# 上面不是获得了image的地址么,下面这个函数就是根据地址获取图片 
def load_image(addr): # A function to Load image 
  img = cv2.imread(addr) 
  img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) 
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 
  # 这里/255是为了将像素值归一化到[0,1] 
  img = img / 255. 
  img = img.astype(np.float32) 
  return img 
 
# 将数据转化成对应的属性 
def _int64_feature(value):  
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
 
 
def _bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
 
 
def _float_feature(value): 
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 
 
# 下面这段就开始把数据写入TFRecods文件 
 
train_filename = '/path/to/train.tfrecords' # 输出文件地址 
 
# 创建一个writer来写 TFRecords 文件 
writer = tf.python_io.TFRecordWriter(train_filename) 
 
for i in range(len(train_addrs)): 
  # 这是写入操作可视化处理 
  if not i % 1000: 
    print('Train data: {}/{}'.format(i, len(train_addrs))) 
    sys.stdout.flush() 
  # 加载图片 
  img = load_image(train_addrs[i]) 
 
  label = train_labels[i] 
 
  # 创建一个属性(feature) 
  feature = {'train/label': _int64_feature(label), 
        'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))} 
 
  # 创建一个 example protocol buffer 
  example = tf.train.Example(features=tf.train.Features(feature=feature)) 
 
  # 将上面的example protocol buffer写入文件 
  writer.write(example.SerializeToString()) 
 
writer.close() 
sys.stdout.flush()

上面只介绍了train.tfrecords文件的生成,其余的validation,test举一反三吧。。

接下来介绍tfrecords文件的读取:

import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
data_path = 'train.tfrecords' # tfrecords 文件的地址 
 
with tf.Session() as sess: 
  # 先定义feature,这里要和之前创建的时候保持一致 
  feature = { 
    'train/image': tf.FixedLenFeature([], tf.string), 
    'train/label': tf.FixedLenFeature([], tf.int64) 
  } 
  # 创建一个队列来维护输入文件列表 
  filename_queue = tf.train.string_input_producer([data_path], num_epochs=1) 
 
  # 定义一个 reader ,读取下一个 record 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
 
  # 解析读入的一个record 
  features = tf.parse_single_example(serialized_example, features=feature) 
 
  # 将字符串解析成图像对应的像素组 
  image = tf.decode_raw(features['train/image'], tf.float32) 
 
  # 将标签转化成int32 
  label = tf.cast(features['train/label'], tf.int32) 
 
  # 这里将图片还原成原来的维度 
  image = tf.reshape(image, [224, 224, 3]) 
 
  # 你还可以进行其他一些预处理.... 
 
  # 这里是创建顺序随机 batches(函数不懂的自行百度) 
  images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, min_after_dequeue=10) 
 
  # 初始化 
  init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 
  sess.run(init_op) 
 
  # 启动多线程处理输入数据 
  coord = tf.train.Coordinator() 
  threads = tf.train.start_queue_runners(coord=coord) 
 
  .... 
 
  #关闭线程 
  coord.request_stop() 
  coord.join(threads) 
  sess.close()

好了,就介绍到这里。。,有什么问题可以留言。。大家一起学习。。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用cStringIO实现临时内存文件访问的方法
Mar 26 Python
Python中的模块和包概念介绍
Apr 13 Python
Python使用面向对象方式创建线程实现12306售票系统
Dec 24 Python
详解Python中的静态方法与类成员方法
Feb 28 Python
python处理multipart/form-data的请求方法
Dec 26 Python
Python任意字符串转16, 32, 64进制的方法
Jun 12 Python
Python udp网络程序实现发送、接收数据功能示例
Dec 09 Python
python 视频逐帧保存为图片的完整实例
Dec 10 Python
Python调用接口合并Excel表代码实例
Mar 31 Python
PyTorch的torch.cat用法
Jun 28 Python
Pycharm 2020.1 版配置优化的详细教程
Aug 07 Python
pytorch训练神经网络爆内存的解决方案
May 22 Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
You might like
php计算数组不为空元素个数的方法
2014/01/27 PHP
PHP房贷计算器实例代码,等额本息,等额本金
2017/04/01 PHP
TP5(thinkPHP5)框架使用ajax实现与后台数据交互的方法小结
2020/02/10 PHP
JavaScript之HTMLCollection接口代码
2011/04/27 Javascript
MooTools 页面滚动浮动层智能定位实现代码
2011/08/23 Javascript
NodeJs中的非阻塞方法介绍
2012/06/05 NodeJs
深入Javascript函数、递归与闭包(执行环境、变量对象与作用域链)使用详解
2013/05/08 Javascript
html+javascript实现可拖动可提交的弹出层对话框效果
2013/08/05 Javascript
Jquery $.getJSON 在IE下的缓存问题解决方法
2014/10/10 Javascript
jquery 点击元素后,滚动条滚动至该元素位置的方法
2016/08/05 Javascript
js实现文字超出部分用省略号代替实例代码
2016/09/01 Javascript
Node.js中路径处理模块path详解
2016/11/14 Javascript
简单实现JS计算器功能
2016/12/21 Javascript
详解Vue demo实现商品列表的展示
2019/05/07 Javascript
WEEX环境搭建与入门详解
2019/10/16 Javascript
微信小程序利用for循环解决内容变更问题
2020/03/05 Javascript
微信小程序scroll-view隐藏滚动条的方法详解
2020/03/25 Javascript
Python编程之序列操作实例详解
2017/07/22 Python
pandas DataFrame实现几列数据合并成为新的一列方法
2018/06/08 Python
python实现将一个数组逆序输出的方法
2018/06/25 Python
Python Dataframe 指定多列去重、求差集的方法
2018/07/10 Python
Python爬虫实现简单的爬取有道翻译功能示例
2018/07/13 Python
Python 中导入csv数据的三种方法
2018/11/01 Python
python 读取鼠标点击坐标的实例
2018/12/29 Python
python对数组进行排序,并输出排序后对应的索引值方式
2020/02/28 Python
python os模块常用的29种方法使用详解
2020/06/02 Python
基于Python3读写INI配置文件过程解析
2020/07/23 Python
使用CSS3制作饼状旋转载入效果的实例
2015/06/23 HTML / CSS
科尔士百货公司官网:Kohl’s
2016/07/11 全球购物
Baracuta官方网站:Harrington夹克,G9,G4,G10等
2018/03/06 全球购物
美国牙科折扣计划:DentalPlans.com
2019/08/26 全球购物
广告业务员岗位职责
2014/02/06 职场文书
食品安全工作实施方案
2014/03/26 职场文书
高考诚信考试承诺书
2015/04/29 职场文书
人口与计划生育责任书
2015/05/09 职场文书
关于CentOS 8 搭建MongoDB4.4分片集群的问题
2021/10/24 MongoDB