使用TFRecord存取多个数据案例


Posted in Python onFebruary 17, 2020

TensorFlow提供了一种统一的格式来存储数据,就是TFRecord,它可以统一不同的原始数据格式,并且更加有效地管理不同的属性。

TFRecord格式

TFRecord文件中的数据都是用tf.train.Example Protocol Buffer的格式来存储的,tf.train.Example可以被定义为:

message Example{
  Features features = 1
}

message Features{
  map<string, Feature> feature = 1
}

message Feature{
  oneof kind{
    BytesList bytes_list = 1
    FloatList float_list = 1
    Int64List int64_list = 1
  }
}

可以看出Example是一个嵌套的数据结构,其中属性名称可以为一个字符串,其取值可以是字符串BytesList、实数列表FloatList或整数列表Int64List。

将数据转化为TFRecord格式

以下代码是将MNIST输入数据转化为TFRecord格式:

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np


# 生成整数型的属性
def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# 生成浮点型的属性
def _float_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))  
#若想保存为数组,则要改成value=value即可


# 生成字符串型的属性
def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


mnist = input_data.read_data_sets("/tensorflow_google", dtype=tf.uint8, one_hot=True)
images = mnist.train.images
# 训练数据所对应的正确答案,可以作为一个属性保存在TFRecord中
labels = mnist.train.labels
# 训练数据的图像分辨率,这可以作为Example中的一个属性
pixels = images.shape[1]
num_examples = mnist.train.num_examples

# 输出TFRecord文件的地址
filename = "/tensorflow_google/mnist_output.tfrecords"
# 创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
  # 将图像矩阵转换成一个字符串
  image_raw = images[index].tostring()
  # 将一个样例转化为Example Protocol Buffer, 并将所有的信息写入这个数据结构
  example = tf.train.Example(features=tf.train.Features(feature={
    'pixels': _int64_feature(pixels),
    'label': _int64_feature(np.argmax(labels[index])),
    'image_raw': _bytes_feature(image_raw)}))

  # 将一个Example写入TFRecord文件
  writer.write(example.SerializeToString())
writer.close()

本程序将MNIST数据集中所有的训练数据存储到了一个TFRecord文件中,若数据量较大,也可以存入多个文件。

从TFRecord文件中读取数据

以下代码可以从上面代码中的TFRecord中读取单个或多个训练数据:

# -*- coding: utf-8 -*-
import tensorflow as tf

# 创建一个reader来读取TFRecord文件中的样例
reader = tf.TFRecordReader()
# 创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer(["/Users/gaoyue/文档/Program/tensorflow_google/chapter7"
                         "/mnist_output.tfrecords"])

# 从文件中读出一个样例,也可以使用read_up_to函数一次性读取多个样例
# _, serialized_example = reader.read(filename_queue)
_, serialized_example = reader.read_up_to(filename_queue, 6) #读取6个样例
# 解析读入的一个样例,如果需要解析多个样例,可以用parse_example函数
# features = tf.parse_single_example(serialized_example, features={
# 解析多个样例
features = tf.parse_example(serialized_example, features={
  # TensorFlow提供两种不同的属性解析方法
  # 第一种是tf.FixedLenFeature,得到的解析结果为Tensor
  # 第二种是tf.VarLenFeature,得到的解析结果为SparseTensor,用于处理稀疏数据
  # 解析数据的格式需要与写入数据的格式一致
  'image_raw': tf.FixedLenFeature([], tf.string),
  'pixels': tf.FixedLenFeature([], tf.int64),
  'label': tf.FixedLenFeature([], tf.int64),
})

# tf.decode_raw可以将字符串解析成图像对应的像素数组
images = tf.decode_raw(features['image_raw'], tf.uint8)
labels = tf.cast(features['label'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)

sess = tf.Session()
# 启动多线程处理输入数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# 每次运行可以读取TFRecord中的一个样例,当所有样例都读完之后,会重头读取
# for i in range(10):
#   image, label, pixel = sess.run([images, labels, pixels])
#   # print(image, label, pixel)
#   print(label, pixel)

# 读取TFRecord中的前6个样例,若加入循环,则会每次从上次输出的地方继续顺序读6个样例
image, label, pixel = sess.run([images, labels, pixels])
print(label, pixel)

sess.close()

>> [7 3 4 6 1 8] [784 784 784 784 784 784]

输出结果显示,从TFRecord文件中顺序读出前6个样例。

以上这篇使用TFRecord存取多个数据案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现微信公众平台自定义菜单实例
Mar 20 Python
Python中json格式数据的编码与解码方法详解
Jul 01 Python
python微信跳一跳系列之色块轮廓定位棋盘
Feb 26 Python
python中找出numpy array数组的最值及其索引方法
Apr 17 Python
python实现汉诺塔算法
Mar 01 Python
Pycharm2017版本设置启动时默认自动打开项目的方法
Oct 29 Python
Python 利用切片从列表中取出一部分使用的方法
Feb 01 Python
python实现广度优先搜索过程解析
Oct 19 Python
python 导入数据及作图的实现
Dec 03 Python
python mysql自增字段AUTO_INCREMENT值的修改方式
May 18 Python
python中Django文件上传方法详解
Aug 05 Python
详解python使用金山词霸的翻译功能(调试工具断点的使用)
Jan 07 Python
从多个tfrecord文件中无限读取文件的例子
Feb 17 #Python
Python3连接Mysql8.0遇到的问题及处理步骤
Feb 17 #Python
python3连接MySQL8.0的两种方式
Feb 17 #Python
Win10下安装并使用tensorflow-gpu1.8.0+python3.6全过程分析(显卡MX250+CUDA9.0+cudnn)
Feb 17 #Python
Windows下实现将Pascal VOC转化为TFRecords
Feb 17 #Python
tensorflow生成多个tfrecord文件实例
Feb 17 #Python
tensorflow将图片保存为tfrecord和tfrecord的读取方式
Feb 17 #Python
You might like
我的论坛源代码(二)
2006/10/09 PHP
php查看session内容的函数
2008/08/27 PHP
浏览器预览PHP文件时顶部出现空白影响布局分析原因及解决办法
2013/01/11 PHP
自制PHP框架之设计模式
2017/05/07 PHP
YII2框架中excel表格导出的方法详解
2017/07/21 PHP
Laravel框架用户登陆身份验证实现方法详解
2017/09/14 PHP
一段利用WSH获取登录时间的jscript代码
2008/05/11 Javascript
JS获取地址栏参数的小例子
2013/08/23 Javascript
详解Nodejs之静态资源处理
2017/06/05 NodeJs
Vue的elementUI实现自定义主题方法
2018/02/23 Javascript
webpack组织模块打包Library的原理及实现
2018/03/10 Javascript
vue用Object.defineProperty手写一个简单的双向绑定的示例
2018/07/09 Javascript
图文讲解vue的v-if使用方法
2019/02/11 Javascript
js原生map实现的方法总结
2020/01/19 Javascript
JQuery插件tablesorter表格排序实现过程解析
2020/05/28 jQuery
[03:00]DOTA2-DPC中国联赛1月18日Recap集锦
2021/03/11 DOTA
Python基于二分查找实现求整数平方根的方法
2016/05/12 Python
python中数据爬虫requests库使用方法详解
2018/02/11 Python
Python 中包/模块的 `import` 操作代码
2019/04/22 Python
Python 通过微信控制实现app定位发送到个人服务器再转发微信服务器接收位置信息
2019/08/05 Python
Python基本类型的连接组合和互相转换方式(13种)
2019/12/16 Python
Python基础之字符串常见操作经典实例详解
2020/02/26 Python
Pytorch 高效使用GPU的操作
2020/06/27 Python
python中Django文件上传方法详解
2020/08/05 Python
matplotlib常见函数之plt.rcParams、matshow的使用(坐标轴设置)
2021/01/05 Python
白兰氏健康Mall:BRAND’S
2017/11/13 全球购物
Lacoste(法国鳄鱼)加拿大官网:以标志性的POLO衫而闻名
2019/05/15 全球购物
如何写出高性能的JSP和Servlet
2013/01/22 面试题
自我评价如何写好?
2014/01/05 职场文书
运动会广播稿200字(10篇)
2014/10/12 职场文书
四年级数学上册教学计划
2015/01/20 职场文书
委托书格式范文
2015/01/28 职场文书
公司员工手册范本
2015/05/14 职场文书
警示教育观后感
2015/06/17 职场文书
2016年国培心得体会及反思
2016/01/13 职场文书
gateway与spring-boot-starter-web冲突问题的解决
2021/07/16 Java/Android