Tensorflow使用tfrecord输入数据格式


Posted in Python onJune 19, 2018

Tensorflow 提供了一种统一的格式来存储数据,这个格式就是TFRecord,上一篇文章中所提到的方法当数据的来源更复杂,每个样例中的信息更丰富的时候就很难有效的记录输入数据中的信息了,于是Tensorflow提供了TFRecord来统一存储数据,接下来我们就来介绍如何使用TFRecord来同意输入数据的格式。

1. TFRecord格式介绍

TFRecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Example的定义

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> featrue = 1;
};

message Feature{
  oneof kind{
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

从上述代码可以看到,ft.train.Example 的数据结构相对简洁。tf.train.Example中包含了一个从属性名称到取值的字典,其中属性名称为一个字符串,属性的取值可以为字符串(BytesList ),实数列表(FloatList )或整数列表(Int64List )。例如我们可以将解码前的图片作为字符串,图像对应的类别标号作为整数列表。

2. 将自己的数据转化为TFRecord格式

准备数据

在上一篇中,我们为了像伟大的MNIST致敬,所以选择图像的前缀来进行不同类别的分类依据,但是大多数的情况下,在进行分类任务的过程中,不同的类别都会放在不同的文件夹下,而且类别的个数往往浮动性又很大,所以针对这样的情况,我们现在利用不同类别在不同文件夹中的图像来生成TFRecord.

我们在Iris&Contact这个文件夹下有两个文件夹,分别为iris,contact。对于每个文件夹中存放的是对应的图片

转换数据

数据准备好以后,就开始准备生成TFRecord,具体代码如下:

import os 
import tensorflow as tf 
from PIL import Image 
import matplotlib.pyplot as plt 

cwd='/home/ruyiwei/Documents/Iris&Contact/'
classes={'iris','contact'} 
writer= tf.python_io.TFRecordWriter("iris_contact.tfrecords") 

for index,name in enumerate(classes):
  class_path=cwd+name+'/'
  for img_name in os.listdir(class_path): 
    img_path=class_path+img_name 
    img=Image.open(img_path)
    img= img.resize((512,80))
    img_raw=img.tobytes()
    #plt.imshow(img) # if you want to check you image,please delete '#'
    #plt.show()
    example = tf.train.Example(features=tf.train.Features(feature={
      "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
      'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    })) 
    writer.write(example.SerializeToString()) 

writer.close()

3. Tensorflow从TFRecord中读取数据

def read_and_decode(filename): # read iris_contact.tfrecords
  filename_queue = tf.train.string_input_producer([filename])# create a queue

  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)#return file_name and file
  features = tf.parse_single_example(serialized_example,
                    features={
                      'label': tf.FixedLenFeature([], tf.int64),
                      'img_raw' : tf.FixedLenFeature([], tf.string),
                    })#return image and label

  img = tf.decode_raw(features['img_raw'], tf.uint8)
  img = tf.reshape(img, [512, 80, 3]) #reshape image to 512*80*3
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #throw img tensor
  label = tf.cast(features['label'], tf.int32) #throw label tensor
  return img, label

4. 将TFRecord中的数据保存为图片

filename_queue = tf.train.string_input_producer(["iris_contact.tfrecords"]) 
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  #return file and file_name
features = tf.parse_single_example(serialized_example,
                  features={
                    'label': tf.FixedLenFeature([], tf.int64),
                    'img_raw' : tf.FixedLenFeature([], tf.string),
                  }) 
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [512, 80, 3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess: 
  init_op = tf.initialize_all_variables()
  sess.run(init_op)
  coord=tf.train.Coordinator()
  threads= tf.train.start_queue_runners(coord=coord)
  for i in range(20):
    example, l = sess.run([image,label])#take out image and label
    img=Image.fromarray(example, 'RGB')
    img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#save image
    print(example, l)
  coord.request_stop()
  coord.join(threads)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于Python的接口测试框架实例
Nov 04 Python
Python安装官方whl包和tar.gz包的方法(推荐)
Jun 04 Python
Django2.1集成xadmin管理后台所遇到的错误集锦(填坑)
Dec 20 Python
python如何实现不用装饰器实现登陆器小程序
Dec 14 Python
Python编程快速上手——Excel到CSV的转换程序案例分析
Feb 28 Python
python GUI库图形界面开发之PyQt5布局控件QVBoxLayout详细使用方法与实例
Mar 06 Python
完美解决ARIMA模型中plot_acf画不出图的问题
Jun 04 Python
Python sqlalchemy时间戳及密码管理实现代码详解
Aug 01 Python
python实现发送带附件的邮件代码分享
Sep 22 Python
Python3.9新特性详解
Oct 10 Python
在 Golang 中实现 Cache::remember 方法详解
Mar 30 Python
Python实现视频中添加音频工具详解
Dec 06 Python
Tensorflow 训练自己的数据集将数据直接导入到内存
Jun 19 #Python
python如何爬取个性签名
Jun 19 #Python
详解TensorFlow查看ckpt中变量的几种方法
Jun 19 #Python
TensorFlow 滑动平均的示例代码
Jun 19 #Python
python3个性签名设计实现代码
Jun 19 #Python
TensorFlow 模型载入方法汇总(小结)
Jun 19 #Python
python3爬虫之设计签名小程序
Jun 19 #Python
You might like
php图像处理函数大全(推荐收藏)
2013/07/11 PHP
PHP入门教程之表单与验证实例详解
2016/09/11 PHP
Laravel中日期时间处理包Carbon的简单使用
2017/09/21 PHP
详解Laravel5.6 Passport实现Api接口认证
2018/07/27 PHP
IE与FireFox的兼容性问题分析
2007/04/22 Javascript
仅IE9/10同时支持script元素的onload和onreadystatechange事件分析
2011/04/27 Javascript
原生js实现跨浏览器获取鼠标按键的值
2013/04/08 Javascript
javascipt匹配单行和多行注释的正则表达式
2013/11/20 Javascript
JavaScript子类用Object.getPrototypeOf去调用父类方法解析
2013/12/05 Javascript
js获取select标签的值且兼容IE与firefox
2013/12/30 Javascript
基于NodeJS的前后端分离的思考与实践(六)Nginx + Node.js + Java 的软件栈部署实践
2014/09/26 NodeJs
JavaScript获取当前cpu使用率的方法
2015/12/15 Javascript
JavaScript生成.xls文件的代码
2016/12/22 Javascript
利用JavaScript在网页实现八数码启发式A*算法动画效果
2017/04/16 Javascript
jQuery中clone()函数实现表单中增加和减少输入项
2017/05/13 jQuery
详解使用nvm管理多版本node的方法
2017/08/30 Javascript
Vuex 快速入门(简单易懂)
2018/09/20 Javascript
详解vue-router数据加载与缓存使用总结
2018/10/29 Javascript
vue-cli 3.x配置跨域代理的实现方法
2019/04/12 Javascript
Javascript Dom元素获取和添加详解
2019/09/24 Javascript
antd的select下拉框因为数据量太大造成卡顿的解决方式
2020/10/31 Javascript
在python3环境下的Django中使用MySQL数据库的实例
2017/08/29 Python
简述:我为什么选择Python而不是Matlab和R语言
2017/11/14 Python
python 爬虫一键爬取 淘宝天猫宝贝页面主图颜色图和详情图的教程
2018/05/22 Python
python科学计算之numpy——ufunc函数用法
2019/11/25 Python
windows下的pycharm安装及其设置中文菜单
2020/04/23 Python
python中子类与父类的关系基础知识点
2021/02/02 Python
EJB包括(SessionBean,EntityBean)说出他们的生命周期,及如何管理事务的
2015/07/24 面试题
出纳岗位职责范本
2013/12/01 职场文书
学习十八届四中全会精神思想汇报
2014/10/23 职场文书
解除劳动关系协议书2篇
2014/11/28 职场文书
孝女彩金观后感
2015/06/10 职场文书
七一晚会主持词
2015/06/29 职场文书
会计继续教育培训心得体会
2016/01/19 职场文书
Python中glob库实现文件名的匹配
2021/06/18 Python
Python实现照片卡通化
2021/12/06 Python