Tensorflow使用tfrecord输入数据格式


Posted in Python onJune 19, 2018

Tensorflow 提供了一种统一的格式来存储数据,这个格式就是TFRecord,上一篇文章中所提到的方法当数据的来源更复杂,每个样例中的信息更丰富的时候就很难有效的记录输入数据中的信息了,于是Tensorflow提供了TFRecord来统一存储数据,接下来我们就来介绍如何使用TFRecord来同意输入数据的格式。

1. TFRecord格式介绍

TFRecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Example的定义

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> featrue = 1;
};

message Feature{
  oneof kind{
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

从上述代码可以看到,ft.train.Example 的数据结构相对简洁。tf.train.Example中包含了一个从属性名称到取值的字典,其中属性名称为一个字符串,属性的取值可以为字符串(BytesList ),实数列表(FloatList )或整数列表(Int64List )。例如我们可以将解码前的图片作为字符串,图像对应的类别标号作为整数列表。

2. 将自己的数据转化为TFRecord格式

准备数据

在上一篇中,我们为了像伟大的MNIST致敬,所以选择图像的前缀来进行不同类别的分类依据,但是大多数的情况下,在进行分类任务的过程中,不同的类别都会放在不同的文件夹下,而且类别的个数往往浮动性又很大,所以针对这样的情况,我们现在利用不同类别在不同文件夹中的图像来生成TFRecord.

我们在Iris&Contact这个文件夹下有两个文件夹,分别为iris,contact。对于每个文件夹中存放的是对应的图片

转换数据

数据准备好以后,就开始准备生成TFRecord,具体代码如下:

import os 
import tensorflow as tf 
from PIL import Image 
import matplotlib.pyplot as plt 

cwd='/home/ruyiwei/Documents/Iris&Contact/'
classes={'iris','contact'} 
writer= tf.python_io.TFRecordWriter("iris_contact.tfrecords") 

for index,name in enumerate(classes):
  class_path=cwd+name+'/'
  for img_name in os.listdir(class_path): 
    img_path=class_path+img_name 
    img=Image.open(img_path)
    img= img.resize((512,80))
    img_raw=img.tobytes()
    #plt.imshow(img) # if you want to check you image,please delete '#'
    #plt.show()
    example = tf.train.Example(features=tf.train.Features(feature={
      "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
      'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    })) 
    writer.write(example.SerializeToString()) 

writer.close()

3. Tensorflow从TFRecord中读取数据

def read_and_decode(filename): # read iris_contact.tfrecords
  filename_queue = tf.train.string_input_producer([filename])# create a queue

  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)#return file_name and file
  features = tf.parse_single_example(serialized_example,
                    features={
                      'label': tf.FixedLenFeature([], tf.int64),
                      'img_raw' : tf.FixedLenFeature([], tf.string),
                    })#return image and label

  img = tf.decode_raw(features['img_raw'], tf.uint8)
  img = tf.reshape(img, [512, 80, 3]) #reshape image to 512*80*3
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #throw img tensor
  label = tf.cast(features['label'], tf.int32) #throw label tensor
  return img, label

4. 将TFRecord中的数据保存为图片

filename_queue = tf.train.string_input_producer(["iris_contact.tfrecords"]) 
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  #return file and file_name
features = tf.parse_single_example(serialized_example,
                  features={
                    'label': tf.FixedLenFeature([], tf.int64),
                    'img_raw' : tf.FixedLenFeature([], tf.string),
                  }) 
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [512, 80, 3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess: 
  init_op = tf.initialize_all_variables()
  sess.run(init_op)
  coord=tf.train.Coordinator()
  threads= tf.train.start_queue_runners(coord=coord)
  for i in range(20):
    example, l = sess.run([image,label])#take out image and label
    img=Image.fromarray(example, 'RGB')
    img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#save image
    print(example, l)
  coord.request_stop()
  coord.join(threads)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python实现BT种子和磁力链接的相互转换
Nov 09 Python
Python基于matplotlib绘制栈式直方图的方法示例
Aug 09 Python
Python 错误和异常代码详解
Jan 29 Python
解决pycharm无法识别本地site-packages的问题
Oct 13 Python
python 用下标截取字符串的实例
Dec 25 Python
人工神经网络算法知识点总结
Jun 11 Python
python celery分布式任务队列的使用详解
Jul 08 Python
pandas按行按列遍历Dataframe的几种方式
Oct 23 Python
YUV转为jpg图像的实现
Dec 09 Python
Django用内置方法实现简单搜索功能的方法
Dec 18 Python
K近邻法(KNN)相关知识总结以及如何用python实现
Jan 28 Python
Python爬虫爬取微博热搜保存为 Markdown 文件的源码
Feb 22 Python
Tensorflow 训练自己的数据集将数据直接导入到内存
Jun 19 #Python
python如何爬取个性签名
Jun 19 #Python
详解TensorFlow查看ckpt中变量的几种方法
Jun 19 #Python
TensorFlow 滑动平均的示例代码
Jun 19 #Python
python3个性签名设计实现代码
Jun 19 #Python
TensorFlow 模型载入方法汇总(小结)
Jun 19 #Python
python3爬虫之设计签名小程序
Jun 19 #Python
You might like
关于Anemometer图形化显示MySQL慢日志的工具搭建及使用的详细介绍
2020/07/13 PHP
PHP SESSION跨页面传递失败解决方案
2020/12/11 PHP
强大的jquery插件jqeuryUI做网页对话框效果!简单
2011/04/14 Javascript
最新28个很棒的jQuery 教程
2011/05/28 Javascript
JS求平均值的小例子
2013/11/29 Javascript
JavaScript格式化日期时间的方法和自定义格式化函数示例
2014/04/04 Javascript
jquery插件jquery.beforeafter.js实现左右拖拽分隔条对比图片的方法
2015/08/07 Javascript
vue.js国际化 vue-i18n插件的使用详解
2017/07/07 Javascript
IScroll那些事_当内容不足时下拉刷新的解决方法
2017/07/18 Javascript
JavaScript实现微信号随机切换代码
2018/03/09 Javascript
JavaScript对象的特性与实践应用深入详解
2018/12/30 Javascript
JavaScript JSON数据处理全集(小结)
2019/08/15 Javascript
wxpython 最小化到托盘与欢迎图片的实现方法
2014/06/09 Python
Python读取指定目录下指定后缀文件并保存为docx
2017/04/23 Python
Python实现的爬虫刷回复功能示例
2018/06/07 Python
分析经典Python开发工程师面试题
2019/04/08 Python
python用match()函数爬数据方法详解
2019/07/23 Python
基于Python数据结构之递归与回溯搜索
2020/02/26 Python
浅析python 通⽤爬⾍和聚焦爬⾍
2020/09/28 Python
使用HTML5的链接预取功能(link prefetching)给网站提速
2012/12/13 HTML / CSS
5 个强大的HTML5 API 函数推荐
2014/11/19 HTML / CSS
HTML5实现应用程序缓存(Application Cache)
2020/06/16 HTML / CSS
英国著名书店:Foyles
2018/12/01 全球购物
香港彩色隐形眼镜在线商店:Stunninglens(全球免费送货)
2019/05/10 全球购物
Java如何调用外部Exe程序
2015/07/04 面试题
Shell脚本如何向终端输出信息
2014/04/25 面试题
应聘护士自荐信
2013/10/21 职场文书
公民授权委托书
2014/10/15 职场文书
财务工作失职检讨书
2014/11/21 职场文书
监守自盗观后感
2015/06/10 职场文书
总结会主持词
2015/07/02 职场文书
趣味运动会广播稿
2015/08/19 职场文书
2017元旦晚会开幕词
2016/03/03 职场文书
Redis集群的关闭与重启操作
2021/07/07 Redis
Javascript webpack动态import
2022/04/19 Javascript
MySQL数据库 任意ip连接方法
2022/05/20 MySQL