使用TFRecord存取多个数据案例


Posted in Python onFebruary 17, 2020

TensorFlow提供了一种统一的格式来存储数据,就是TFRecord,它可以统一不同的原始数据格式,并且更加有效地管理不同的属性。

TFRecord格式

TFRecord文件中的数据都是用tf.train.Example Protocol Buffer的格式来存储的,tf.train.Example可以被定义为:

message Example{
  Features features = 1
}

message Features{
  map<string, Feature> feature = 1
}

message Feature{
  oneof kind{
    BytesList bytes_list = 1
    FloatList float_list = 1
    Int64List int64_list = 1
  }
}

可以看出Example是一个嵌套的数据结构,其中属性名称可以为一个字符串,其取值可以是字符串BytesList、实数列表FloatList或整数列表Int64List。

将数据转化为TFRecord格式

以下代码是将MNIST输入数据转化为TFRecord格式:

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np


# 生成整数型的属性
def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# 生成浮点型的属性
def _float_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))  
#若想保存为数组,则要改成value=value即可


# 生成字符串型的属性
def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


mnist = input_data.read_data_sets("/tensorflow_google", dtype=tf.uint8, one_hot=True)
images = mnist.train.images
# 训练数据所对应的正确答案,可以作为一个属性保存在TFRecord中
labels = mnist.train.labels
# 训练数据的图像分辨率,这可以作为Example中的一个属性
pixels = images.shape[1]
num_examples = mnist.train.num_examples

# 输出TFRecord文件的地址
filename = "/tensorflow_google/mnist_output.tfrecords"
# 创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
  # 将图像矩阵转换成一个字符串
  image_raw = images[index].tostring()
  # 将一个样例转化为Example Protocol Buffer, 并将所有的信息写入这个数据结构
  example = tf.train.Example(features=tf.train.Features(feature={
    'pixels': _int64_feature(pixels),
    'label': _int64_feature(np.argmax(labels[index])),
    'image_raw': _bytes_feature(image_raw)}))

  # 将一个Example写入TFRecord文件
  writer.write(example.SerializeToString())
writer.close()

本程序将MNIST数据集中所有的训练数据存储到了一个TFRecord文件中,若数据量较大,也可以存入多个文件。

从TFRecord文件中读取数据

以下代码可以从上面代码中的TFRecord中读取单个或多个训练数据:

# -*- coding: utf-8 -*-
import tensorflow as tf

# 创建一个reader来读取TFRecord文件中的样例
reader = tf.TFRecordReader()
# 创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer(["/Users/gaoyue/文档/Program/tensorflow_google/chapter7"
                         "/mnist_output.tfrecords"])

# 从文件中读出一个样例,也可以使用read_up_to函数一次性读取多个样例
# _, serialized_example = reader.read(filename_queue)
_, serialized_example = reader.read_up_to(filename_queue, 6) #读取6个样例
# 解析读入的一个样例,如果需要解析多个样例,可以用parse_example函数
# features = tf.parse_single_example(serialized_example, features={
# 解析多个样例
features = tf.parse_example(serialized_example, features={
  # TensorFlow提供两种不同的属性解析方法
  # 第一种是tf.FixedLenFeature,得到的解析结果为Tensor
  # 第二种是tf.VarLenFeature,得到的解析结果为SparseTensor,用于处理稀疏数据
  # 解析数据的格式需要与写入数据的格式一致
  'image_raw': tf.FixedLenFeature([], tf.string),
  'pixels': tf.FixedLenFeature([], tf.int64),
  'label': tf.FixedLenFeature([], tf.int64),
})

# tf.decode_raw可以将字符串解析成图像对应的像素数组
images = tf.decode_raw(features['image_raw'], tf.uint8)
labels = tf.cast(features['label'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)

sess = tf.Session()
# 启动多线程处理输入数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# 每次运行可以读取TFRecord中的一个样例,当所有样例都读完之后,会重头读取
# for i in range(10):
#   image, label, pixel = sess.run([images, labels, pixels])
#   # print(image, label, pixel)
#   print(label, pixel)

# 读取TFRecord中的前6个样例,若加入循环,则会每次从上次输出的地方继续顺序读6个样例
image, label, pixel = sess.run([images, labels, pixels])
print(label, pixel)

sess.close()

>> [7 3 4 6 1 8] [784 784 784 784 784 784]

输出结果显示,从TFRecord文件中顺序读出前6个样例。

以上这篇使用TFRecord存取多个数据案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
把MySQL表结构映射为Python中的对象的教程
Apr 07 Python
Python记录详细调用堆栈日志的方法
May 05 Python
python在linux系统下获取系统内存使用情况的方法
May 11 Python
Linux下将Python的Django项目部署到Apache服务器
Dec 24 Python
python爬虫框架scrapy实现模拟登录操作示例
Aug 02 Python
python定间隔取点(np.linspace)的实现
Nov 27 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 Python
使用TensorFlow搭建一个全连接神经网络教程
Feb 06 Python
关于Django Models CharField 参数说明
Mar 31 Python
使用python处理题库表格并转化为word形式的实现
Apr 14 Python
Python grpc超时机制代码示例
Sep 14 Python
从多个tfrecord文件中无限读取文件的例子
Feb 17 #Python
Python3连接Mysql8.0遇到的问题及处理步骤
Feb 17 #Python
python3连接MySQL8.0的两种方式
Feb 17 #Python
Win10下安装并使用tensorflow-gpu1.8.0+python3.6全过程分析(显卡MX250+CUDA9.0+cudnn)
Feb 17 #Python
Windows下实现将Pascal VOC转化为TFRecords
Feb 17 #Python
tensorflow生成多个tfrecord文件实例
Feb 17 #Python
tensorflow将图片保存为tfrecord和tfrecord的读取方式
Feb 17 #Python
You might like
php输出1000以内质数(素数)示例
2014/02/16 PHP
php each 返回数组中当前的键值对并将数组指针向前移动一步实例
2016/11/22 PHP
php抽象方法和抽象类实例分析
2016/12/07 PHP
dojo 之基础篇(三)之向服务器发送数据
2007/03/24 Javascript
JavaScript 拾碎[三] 使用className属性
2010/10/16 Javascript
JS取得绝对路径的实现代码
2015/01/16 Javascript
jQuery实现节点的追加、替换、删除、复制功能示例
2017/07/11 jQuery
JS中关于正则的巧妙操作
2017/08/31 Javascript
Nodejs模块载入运行原理
2018/02/23 NodeJs
vue 权限认证token的实现方法
2018/07/17 Javascript
JavaScript对JSON数组简单排序操作示例
2019/01/31 Javascript
深度了解vue.js中hooks的相关知识
2019/06/14 Javascript
Node快速切换版本、版本回退(降级)、版本更新(升级)
2021/01/07 Javascript
python ElementTree 基本读操作示例
2009/04/09 Python
详解Python2.x中对Unicode编码的使用
2015/04/03 Python
numpy排序与集合运算用法示例
2017/12/15 Python
python实现xlsx文件分析详解
2018/01/02 Python
Django + Uwsgi + Nginx 实现生产环境部署的方法
2018/06/20 Python
详解python中自定义超时异常的几种方法
2019/07/29 Python
Django rstful登陆认证并检查session是否过期代码实例
2019/08/13 Python
opencv转换颜色空间更改图片背景
2019/08/20 Python
解决Python计算矩阵乘向量,矩阵乘实数的一些小错误
2019/08/26 Python
Python 爬虫的原理
2020/07/30 Python
python openCV实现摄像头获取人脸图片
2020/08/20 Python
Python常用扩展插件使用教程解析
2020/11/02 Python
Pycharm创建python文件自动添加日期作者等信息(步骤详解)
2021/02/03 Python
突袭HTML5之Javascript API扩展3—本地存储全新体验
2013/01/31 HTML / CSS
美国知名保健品网站:LuckyVitamin(支持中文)
2017/08/09 全球购物
为什么要优先使用同步代码块而不是同步方法?
2013/01/30 面试题
写一个函数返回1+2+3+…+n的值(假定结果不会超过长整型变量的范围)
2014/09/05 面试题
白莲教口号
2014/06/18 职场文书
父母教会我观后感
2015/06/17 职场文书
2016暑期政治学习心得体会
2016/01/23 职场文书
导游词之上海东方明珠塔
2019/09/25 职场文书
Nginx+SpringBoot实现负载均衡的示例
2021/03/31 Servers
python随机打印成绩排名表
2021/06/23 Python