Tensorflow之构建自己的图片数据集TFrecords的方法


Posted in Python onFebruary 07, 2018

学习谷歌的深度学习终于有点眉目了,给大家分享我的Tensorflow学习历程。

tensorflow的官方中文文档比较生涩,数据集一直采用的MNIST二进制数据集。并没有过多讲述怎么构建自己的图片数据集tfrecords。

流程是:制作数据集—读取数据集—-加入队列

先贴完整的代码:

#encoding=utf-8
import os
import tensorflow as tf
from PIL import Image

cwd = os.getcwd()

classes = {'test','test1','test2'}
#制作二进制数据
def create_record():
  writer = tf.python_io.TFRecordWriter("train.tfrecords")
  for index, name in enumerate(classes):
    class_path = cwd +"/"+ name+"/"
    for img_name in os.listdir(class_path):
      img_path = class_path + img_name
      img = Image.open(img_path)
      img = img.resize((64, 64))
      img_raw = img.tobytes() #将图片转化为原生bytes
      print index,img_raw
      example = tf.train.Example(
        features=tf.train.Features(feature={
          "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
          'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
      writer.write(example.SerializeToString())
  writer.close()

data = create_record()

#读取二进制数据
def read_and_decode(filename):
  # 创建文件队列,不限读取的数量
  filename_queue = tf.train.string_input_producer([filename])
  # create a reader from file queue
  reader = tf.TFRecordReader()
  # reader从文件队列中读入一个序列化的样本
  _, serialized_example = reader.read(filename_queue)
  # get feature from serialized example
  # 解析符号化的样本
  features = tf.parse_single_example(
    serialized_example,
    features={
      'label': tf.FixedLenFeature([], tf.int64),
      'img_raw': tf.FixedLenFeature([], tf.string)
    }
  )
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8)
  img = tf.reshape(img, [64, 64, 3])
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(label, tf.int32)
  return img, label

if __name__ == '__main__':
  if 0:
    data = create_record("train.tfrecords")
  else:
    img, label = read_and_decode("train.tfrecords")
    print "tengxing",img,label
    #使用shuffle_batch可以随机打乱输入 next_batch挨着往下取
    # shuffle_batch才能实现[img,label]的同步,也即特征和label的同步,不然可能输入的特征和label不匹配
    # 比如只有这样使用,才能使img和label一一对应,每次提取一个image和对应的label
    # shuffle_batch返回的值就是RandomShuffleQueue.dequeue_many()的结果
    # Shuffle_batch构建了一个RandomShuffleQueue,并不断地把单个的[img,label],送入队列中
    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                          batch_size=4, capacity=2000,
                          min_after_dequeue=1000)

    # 初始化所有的op
    init = tf.initialize_all_variables()

    with tf.Session() as sess:
      sess.run(init)
      # 启动队列
      threads = tf.train.start_queue_runners(sess=sess)
      for i in range(5):
        print img_batch.shape,label_batch
        val, l = sess.run([img_batch, label_batch])
        # l = to_categorical(l, 12)
        print(val.shape, l)

制作数据集

#制作二进制数据
def create_record():
  cwd = os.getcwd()
  classes = {'1','2','3'}
  writer = tf.python_io.TFRecordWriter("train.tfrecords")
  for index, name in enumerate(classes):
    class_path = cwd +"/"+ name+"/"
    for img_name in os.listdir(class_path):
      img_path = class_path + img_name
      img = Image.open(img_path)
      img = img.resize((28, 28))
      img_raw = img.tobytes() #将图片转化为原生bytes
      #print index,img_raw
      example = tf.train.Example(
        features=tf.train.Features(
          feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
          }
        )
      )
      writer.write(example.SerializeToString())
  writer.close()

TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

读取数据集

#读取二进制数据
def read_and_decode(filename):
  # 创建文件队列,不限读取的数量
  filename_queue = tf.train.string_input_producer([filename])
  # create a reader from file queue
  reader = tf.TFRecordReader()
  # reader从文件队列中读入一个序列化的样本
  _, serialized_example = reader.read(filename_queue)
  # get feature from serialized example
  # 解析符号化的样本
  features = tf.parse_single_example(
    serialized_example,
    features={
      'label': tf.FixedLenFeature([], tf.int64),
      'img_raw': tf.FixedLenFeature([], tf.string)
    }
  )
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8)
  img = tf.reshape(img, [64, 64, 3])
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(label, tf.int32)
  return img, label

一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List

加入队列

with tf.Session() as sess:
      sess.run(init)
      # 启动队列
      threads = tf.train.start_queue_runners(sess=sess)
      for i in range(5):
        print img_batch.shape,label_batch
        val, l = sess.run([img_batch, label_batch])
        # l = to_categorical(l, 12)
        print(val.shape, l)

这样就可以的到和tensorflow官方的二进制数据集了,

注意:

  1. 启动队列那条code不要忘记,不然卡死
  2. 使用的时候记得使用val和l,不然会报类型错误:TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.
  3. 算交叉熵时候:cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits,labels)算交叉熵
  4. 最后评估的时候用tf.nn.in_top_k(logits,labels,1)选logits最大的数的索引和label比较
  5. cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))算交叉熵,所以label必须转成one-hot向量

实例2:将图片文件夹下的图片转存tfrecords的数据集。

############################################################################################ 
#!/usr/bin/python2.7 
# -*- coding: utf-8 -*- 
#Author : zhaoqinghui 
#Date  : 2016.5.10 
#Function: image convert to tfrecords  
############################################################################################# 
 
import tensorflow as tf 
import numpy as np 
import cv2 
import os 
import os.path 
from PIL import Image 
 
#参数设置 
############################################################################################### 
train_file = 'train.txt' #训练图片 
name='train'   #生成train.tfrecords 
output_directory='./tfrecords' 
resize_height=32 #存储图片高度 
resize_width=32 #存储图片宽度 
############################################################################################### 
def _int64_feature(value): 
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
 
def _bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
 
def load_file(examples_list_file): 
  lines = np.genfromtxt(examples_list_file, delimiter=" ", dtype=[('col1', 'S120'), ('col2', 'i8')]) 
  examples = [] 
  labels = [] 
  for example, label in lines: 
    examples.append(example) 
    labels.append(label) 
  return np.asarray(examples), np.asarray(labels), len(lines) 
 
def extract_image(filename, resize_height, resize_width): 
  image = cv2.imread(filename) 
  image = cv2.resize(image, (resize_height, resize_width)) 
  b,g,r = cv2.split(image)     
  rgb_image = cv2.merge([r,g,b])    
  return rgb_image 
 
def transform2tfrecord(train_file, name, output_directory, resize_height, resize_width): 
  if not os.path.exists(output_directory) or os.path.isfile(output_directory): 
    os.makedirs(output_directory) 
  _examples, _labels, examples_num = load_file(train_file) 
  filename = output_directory + "/" + name + '.tfrecords' 
  writer = tf.python_io.TFRecordWriter(filename) 
  for i, [example, label] in enumerate(zip(_examples, _labels)): 
    print('No.%d' % (i)) 
    image = extract_image(example, resize_height, resize_width) 
    print('shape: %d, %d, %d, label: %d' % (image.shape[0], image.shape[1], image.shape[2], label)) 
    image_raw = image.tostring() 
    example = tf.train.Example(features=tf.train.Features(feature={ 
      'image_raw': _bytes_feature(image_raw), 
      'height': _int64_feature(image.shape[0]), 
      'width': _int64_feature(image.shape[1]), 
      'depth': _int64_feature(image.shape[2]), 
      'label': _int64_feature(label) 
    })) 
    writer.write(example.SerializeToString()) 
  writer.close() 
 
def disp_tfrecords(tfrecord_list_file): 
  filename_queue = tf.train.string_input_producer([tfrecord_list_file]) 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
  features = tf.parse_single_example( 
    serialized_example, 
 features={ 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'height': tf.FixedLenFeature([], tf.int64), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'depth': tf.FixedLenFeature([], tf.int64), 
     'label': tf.FixedLenFeature([], tf.int64) 
   } 
  ) 
  image = tf.decode_raw(features['image_raw'], tf.uint8) 
  #print(repr(image)) 
  height = features['height'] 
  width = features['width'] 
  depth = features['depth'] 
  label = tf.cast(features['label'], tf.int32) 
  init_op = tf.initialize_all_variables() 
  resultImg=[] 
  resultLabel=[] 
  with tf.Session() as sess: 
    sess.run(init_op) 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(sess=sess, coord=coord) 
    for i in range(21): 
      image_eval = image.eval() 
      resultLabel.append(label.eval()) 
      image_eval_reshape = image_eval.reshape([height.eval(), width.eval(), depth.eval()]) 
      resultImg.append(image_eval_reshape) 
      pilimg = Image.fromarray(np.asarray(image_eval_reshape)) 
      pilimg.show() 
    coord.request_stop() 
    coord.join(threads) 
    sess.close() 
  return resultImg,resultLabel 
 
def read_tfrecord(filename_queuetemp): 
  filename_queue = tf.train.string_input_producer([filename_queuetemp]) 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
  features = tf.parse_single_example( 
    serialized_example, 
    features={ 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'depth': tf.FixedLenFeature([], tf.int64), 
     'label': tf.FixedLenFeature([], tf.int64) 
   } 
  ) 
  image = tf.decode_raw(features['image_raw'], tf.uint8) 
  # image 
  tf.reshape(image, [256, 256, 3]) 
  # normalize 
  image = tf.cast(image, tf.float32) * (1. /255) - 0.5 
  # label 
  label = tf.cast(features['label'], tf.int32) 
  return image, label 
 
def test(): 
  transform2tfrecord(train_file, name , output_directory, resize_height, resize_width) #转化函数   
  img,label=disp_tfrecords(output_directory+'/'+name+'.tfrecords') #显示函数 
  img,label=read_tfrecord(output_directory+'/'+name+'.tfrecords') #读取函数 
  print label 
 
if __name__ == '__main__': 
  test()

这样就可以得到自己专属的数据集.tfrecords了  ,它可以直接用于tensorflow的数据集。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现2014火车票查询代码分享
Jan 10 Python
python解析发往本机的数据包示例 (解析数据包)
Jan 16 Python
仅用50行代码实现一个Python编写的计算器的教程
Apr 17 Python
使用Python求解最大公约数的实现方法
Aug 20 Python
浅谈Python 集合(set)类型的操作——并交差
Jun 30 Python
使用pip发布Python程序的方法步骤
Oct 11 Python
Python进阶:生成器 懒人版本的迭代器详解
Jun 29 Python
python实现两张图片拼接为一张图片并保存
Jul 16 Python
python redis连接 有序集合去重的代码
Aug 04 Python
Python中list循环遍历删除数据的正确方法
Sep 02 Python
python读文件的步骤
Oct 08 Python
属性与 @property 方法让你的python更高效
Sep 21 Python
python深度优先搜索和广度优先搜索
Feb 07 #Python
Python Flask基础教程示例代码
Feb 07 #Python
Python装饰器用法实例总结
Feb 07 #Python
使用apidocJs快速生成在线文档的实例讲解
Feb 07 #Python
Python自定义线程池实现方法分析
Feb 07 #Python
使用apidoc管理RESTful风格Flask项目接口文档方法
Feb 07 #Python
Python列表推导式、字典推导式与集合推导式用法实例分析
Feb 07 #Python
You might like
PHP 内存缓存加速功能memcached安装与用法
2009/09/03 PHP
如何使用php实现评委评分器
2015/07/31 PHP
PHP数据库表操作的封装类及用法实例详解
2016/07/12 PHP
在Thinkphp中使用ajax实现无刷新分页的方法
2016/10/25 PHP
dojo 之基础篇(三)之向服务器发送数据
2007/03/24 Javascript
仅IE不支持setTimeout/setInterval函数的第三个以上参数
2011/05/25 Javascript
使用js操作cookie的一点小收获分享
2013/09/03 Javascript
Google Dart编程语法和基本类型学习教程
2013/11/27 Javascript
关闭浏览器窗口弹出提示框并且可以控制其失效
2014/04/15 Javascript
js 去除字符串第一位逗号的方法
2014/06/07 Javascript
基于javascript的JSON格式页面展示美化方法
2014/07/02 Javascript
JavaScript函数学习总结以及相关的编程习惯指南
2015/11/16 Javascript
Javascript 创建类并动态添加属性及方法的简单实现
2016/10/20 Javascript
利用Js+Css实现折纸动态导航效果实例源码
2017/01/25 Javascript
原生JavaScript实现精美的淘宝轮播图效果示例【附demo源码下载】
2017/05/27 Javascript
详解用vue编写弹出框组件
2017/07/04 Javascript
Vuex实现计数器以及列表展示效果
2018/03/10 Javascript
用vue2.0实现点击选中active其他选项互斥的效果
2018/04/12 Javascript
JavaScript实现封闭区域布尔运算的示例代码
2018/06/25 Javascript
element-ui 时间选择器限制范围的实现(随动)
2019/01/09 Javascript
解决LayUI数据表格复选框不居中显示的问题
2019/09/25 Javascript
jsonp格式前端发送和后台接受写法的代码详解
2019/11/07 Javascript
js刷新页面location.reload()用法详解
2019/12/09 Javascript
基于pycharm导入模块显示不存在的解决方法
2018/10/13 Python
python闭包、深浅拷贝、垃圾回收、with语句知识点汇总
2020/03/11 Python
Python字符串格式化常用手段及注意事项
2020/06/17 Python
Python趣味入门教程之循环语句while
2020/08/26 Python
python实现计算图形面积
2021/02/22 Python
用纯css3和html制作泡沫对话框实现代码
2013/03/21 HTML / CSS
香港零食网购:上仓胃子
2020/06/08 全球购物
小学数学国培感言
2014/03/10 职场文书
知名企业招聘广告词大全
2014/03/18 职场文书
2015年复活节活动总结
2015/02/27 职场文书
前台接待员岗位职责
2015/04/15 职场文书
红色电影观后感
2015/06/18 职场文书
公司酒会致辞
2015/07/30 职场文书