Tensorflow之构建自己的图片数据集TFrecords的方法


Posted in Python onFebruary 07, 2018

学习谷歌的深度学习终于有点眉目了,给大家分享我的Tensorflow学习历程。

tensorflow的官方中文文档比较生涩,数据集一直采用的MNIST二进制数据集。并没有过多讲述怎么构建自己的图片数据集tfrecords。

流程是:制作数据集—读取数据集—-加入队列

先贴完整的代码:

#encoding=utf-8
import os
import tensorflow as tf
from PIL import Image

cwd = os.getcwd()

classes = {'test','test1','test2'}
#制作二进制数据
def create_record():
  writer = tf.python_io.TFRecordWriter("train.tfrecords")
  for index, name in enumerate(classes):
    class_path = cwd +"/"+ name+"/"
    for img_name in os.listdir(class_path):
      img_path = class_path + img_name
      img = Image.open(img_path)
      img = img.resize((64, 64))
      img_raw = img.tobytes() #将图片转化为原生bytes
      print index,img_raw
      example = tf.train.Example(
        features=tf.train.Features(feature={
          "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
          'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
      writer.write(example.SerializeToString())
  writer.close()

data = create_record()

#读取二进制数据
def read_and_decode(filename):
  # 创建文件队列,不限读取的数量
  filename_queue = tf.train.string_input_producer([filename])
  # create a reader from file queue
  reader = tf.TFRecordReader()
  # reader从文件队列中读入一个序列化的样本
  _, serialized_example = reader.read(filename_queue)
  # get feature from serialized example
  # 解析符号化的样本
  features = tf.parse_single_example(
    serialized_example,
    features={
      'label': tf.FixedLenFeature([], tf.int64),
      'img_raw': tf.FixedLenFeature([], tf.string)
    }
  )
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8)
  img = tf.reshape(img, [64, 64, 3])
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(label, tf.int32)
  return img, label

if __name__ == '__main__':
  if 0:
    data = create_record("train.tfrecords")
  else:
    img, label = read_and_decode("train.tfrecords")
    print "tengxing",img,label
    #使用shuffle_batch可以随机打乱输入 next_batch挨着往下取
    # shuffle_batch才能实现[img,label]的同步,也即特征和label的同步,不然可能输入的特征和label不匹配
    # 比如只有这样使用,才能使img和label一一对应,每次提取一个image和对应的label
    # shuffle_batch返回的值就是RandomShuffleQueue.dequeue_many()的结果
    # Shuffle_batch构建了一个RandomShuffleQueue,并不断地把单个的[img,label],送入队列中
    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                          batch_size=4, capacity=2000,
                          min_after_dequeue=1000)

    # 初始化所有的op
    init = tf.initialize_all_variables()

    with tf.Session() as sess:
      sess.run(init)
      # 启动队列
      threads = tf.train.start_queue_runners(sess=sess)
      for i in range(5):
        print img_batch.shape,label_batch
        val, l = sess.run([img_batch, label_batch])
        # l = to_categorical(l, 12)
        print(val.shape, l)

制作数据集

#制作二进制数据
def create_record():
  cwd = os.getcwd()
  classes = {'1','2','3'}
  writer = tf.python_io.TFRecordWriter("train.tfrecords")
  for index, name in enumerate(classes):
    class_path = cwd +"/"+ name+"/"
    for img_name in os.listdir(class_path):
      img_path = class_path + img_name
      img = Image.open(img_path)
      img = img.resize((28, 28))
      img_raw = img.tobytes() #将图片转化为原生bytes
      #print index,img_raw
      example = tf.train.Example(
        features=tf.train.Features(
          feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
          }
        )
      )
      writer.write(example.SerializeToString())
  writer.close()

TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

读取数据集

#读取二进制数据
def read_and_decode(filename):
  # 创建文件队列,不限读取的数量
  filename_queue = tf.train.string_input_producer([filename])
  # create a reader from file queue
  reader = tf.TFRecordReader()
  # reader从文件队列中读入一个序列化的样本
  _, serialized_example = reader.read(filename_queue)
  # get feature from serialized example
  # 解析符号化的样本
  features = tf.parse_single_example(
    serialized_example,
    features={
      'label': tf.FixedLenFeature([], tf.int64),
      'img_raw': tf.FixedLenFeature([], tf.string)
    }
  )
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8)
  img = tf.reshape(img, [64, 64, 3])
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(label, tf.int32)
  return img, label

一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List

加入队列

with tf.Session() as sess:
      sess.run(init)
      # 启动队列
      threads = tf.train.start_queue_runners(sess=sess)
      for i in range(5):
        print img_batch.shape,label_batch
        val, l = sess.run([img_batch, label_batch])
        # l = to_categorical(l, 12)
        print(val.shape, l)

这样就可以的到和tensorflow官方的二进制数据集了,

注意:

  1. 启动队列那条code不要忘记,不然卡死
  2. 使用的时候记得使用val和l,不然会报类型错误:TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.
  3. 算交叉熵时候:cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits,labels)算交叉熵
  4. 最后评估的时候用tf.nn.in_top_k(logits,labels,1)选logits最大的数的索引和label比较
  5. cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))算交叉熵,所以label必须转成one-hot向量

实例2:将图片文件夹下的图片转存tfrecords的数据集。

############################################################################################ 
#!/usr/bin/python2.7 
# -*- coding: utf-8 -*- 
#Author : zhaoqinghui 
#Date  : 2016.5.10 
#Function: image convert to tfrecords  
############################################################################################# 
 
import tensorflow as tf 
import numpy as np 
import cv2 
import os 
import os.path 
from PIL import Image 
 
#参数设置 
############################################################################################### 
train_file = 'train.txt' #训练图片 
name='train'   #生成train.tfrecords 
output_directory='./tfrecords' 
resize_height=32 #存储图片高度 
resize_width=32 #存储图片宽度 
############################################################################################### 
def _int64_feature(value): 
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
 
def _bytes_feature(value): 
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
 
def load_file(examples_list_file): 
  lines = np.genfromtxt(examples_list_file, delimiter=" ", dtype=[('col1', 'S120'), ('col2', 'i8')]) 
  examples = [] 
  labels = [] 
  for example, label in lines: 
    examples.append(example) 
    labels.append(label) 
  return np.asarray(examples), np.asarray(labels), len(lines) 
 
def extract_image(filename, resize_height, resize_width): 
  image = cv2.imread(filename) 
  image = cv2.resize(image, (resize_height, resize_width)) 
  b,g,r = cv2.split(image)     
  rgb_image = cv2.merge([r,g,b])    
  return rgb_image 
 
def transform2tfrecord(train_file, name, output_directory, resize_height, resize_width): 
  if not os.path.exists(output_directory) or os.path.isfile(output_directory): 
    os.makedirs(output_directory) 
  _examples, _labels, examples_num = load_file(train_file) 
  filename = output_directory + "/" + name + '.tfrecords' 
  writer = tf.python_io.TFRecordWriter(filename) 
  for i, [example, label] in enumerate(zip(_examples, _labels)): 
    print('No.%d' % (i)) 
    image = extract_image(example, resize_height, resize_width) 
    print('shape: %d, %d, %d, label: %d' % (image.shape[0], image.shape[1], image.shape[2], label)) 
    image_raw = image.tostring() 
    example = tf.train.Example(features=tf.train.Features(feature={ 
      'image_raw': _bytes_feature(image_raw), 
      'height': _int64_feature(image.shape[0]), 
      'width': _int64_feature(image.shape[1]), 
      'depth': _int64_feature(image.shape[2]), 
      'label': _int64_feature(label) 
    })) 
    writer.write(example.SerializeToString()) 
  writer.close() 
 
def disp_tfrecords(tfrecord_list_file): 
  filename_queue = tf.train.string_input_producer([tfrecord_list_file]) 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
  features = tf.parse_single_example( 
    serialized_example, 
 features={ 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'height': tf.FixedLenFeature([], tf.int64), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'depth': tf.FixedLenFeature([], tf.int64), 
     'label': tf.FixedLenFeature([], tf.int64) 
   } 
  ) 
  image = tf.decode_raw(features['image_raw'], tf.uint8) 
  #print(repr(image)) 
  height = features['height'] 
  width = features['width'] 
  depth = features['depth'] 
  label = tf.cast(features['label'], tf.int32) 
  init_op = tf.initialize_all_variables() 
  resultImg=[] 
  resultLabel=[] 
  with tf.Session() as sess: 
    sess.run(init_op) 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(sess=sess, coord=coord) 
    for i in range(21): 
      image_eval = image.eval() 
      resultLabel.append(label.eval()) 
      image_eval_reshape = image_eval.reshape([height.eval(), width.eval(), depth.eval()]) 
      resultImg.append(image_eval_reshape) 
      pilimg = Image.fromarray(np.asarray(image_eval_reshape)) 
      pilimg.show() 
    coord.request_stop() 
    coord.join(threads) 
    sess.close() 
  return resultImg,resultLabel 
 
def read_tfrecord(filename_queuetemp): 
  filename_queue = tf.train.string_input_producer([filename_queuetemp]) 
  reader = tf.TFRecordReader() 
  _, serialized_example = reader.read(filename_queue) 
  features = tf.parse_single_example( 
    serialized_example, 
    features={ 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'depth': tf.FixedLenFeature([], tf.int64), 
     'label': tf.FixedLenFeature([], tf.int64) 
   } 
  ) 
  image = tf.decode_raw(features['image_raw'], tf.uint8) 
  # image 
  tf.reshape(image, [256, 256, 3]) 
  # normalize 
  image = tf.cast(image, tf.float32) * (1. /255) - 0.5 
  # label 
  label = tf.cast(features['label'], tf.int32) 
  return image, label 
 
def test(): 
  transform2tfrecord(train_file, name , output_directory, resize_height, resize_width) #转化函数   
  img,label=disp_tfrecords(output_directory+'/'+name+'.tfrecords') #显示函数 
  img,label=read_tfrecord(output_directory+'/'+name+'.tfrecords') #读取函数 
  print label 
 
if __name__ == '__main__': 
  test()

这样就可以得到自己专属的数据集.tfrecords了  ,它可以直接用于tensorflow的数据集。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解析xml文件操作实例
Oct 05 Python
在Python中使用第三方模块的教程
Apr 27 Python
Python基础篇之初识Python必看攻略
Jun 23 Python
python 使用get_argument获取url query参数
Apr 28 Python
在Python的一段程序中如何使用多次事件循环详解
Sep 07 Python
使用Python读取安卓手机的屏幕分辨率方法
Mar 31 Python
Python简直是万能的,这5大主要用途你一定要知道!(推荐)
Apr 03 Python
python读写csv文件实例代码
Jul 05 Python
基于python实现自动化办公学习笔记(CSV、word、Excel、PPT)
Aug 06 Python
python 实现提取log文件中的关键句子,并进行统计分析
Dec 24 Python
python读取csv文件指定行的2种方法详解
Feb 13 Python
pytorch 带batch的tensor类型图像显示操作
May 20 Python
python深度优先搜索和广度优先搜索
Feb 07 #Python
Python Flask基础教程示例代码
Feb 07 #Python
Python装饰器用法实例总结
Feb 07 #Python
使用apidocJs快速生成在线文档的实例讲解
Feb 07 #Python
Python自定义线程池实现方法分析
Feb 07 #Python
使用apidoc管理RESTful风格Flask项目接口文档方法
Feb 07 #Python
Python列表推导式、字典推导式与集合推导式用法实例分析
Feb 07 #Python
You might like
超强分页类2.0发布,支持自定义风格,默认4种显示模式
2007/01/02 PHP
PHP获取网站域名和地址的代码
2008/08/17 PHP
MongoDB在PHP中的常用操作小结
2014/02/20 PHP
关于laravel 子查询 & join的使用
2019/10/16 PHP
JavaScript静态的动态
2006/09/18 Javascript
javascript URL编码和解码使用说明
2010/04/12 Javascript
Javascript的setTimeout()使用闭包特性时需要注意的问题
2014/09/23 Javascript
一款基jquery超炫的动画导航菜单可响应单击事件
2014/11/02 Javascript
jQuery中index()方法用法实例
2014/12/27 Javascript
JavaScript动态提示输入框输入字数的方法
2015/07/27 Javascript
jquery实现页面常用的返回顶部效果
2016/03/04 Javascript
一款简单的jQuery图片标注效果附源码下载
2016/03/22 Javascript
Bootstrap3制作自己的导航栏
2016/05/12 Javascript
JS常见简单正则表达式验证功能小结【手机,地址,企业税号,金额,身份证等】
2017/01/22 Javascript
canvas实现图像布局填充功能
2017/02/06 Javascript
js实现数组去重方法及效率?Ρ? target=
2017/02/14 Javascript
js实现延迟加载的几种方法
2017/04/24 Javascript
原生JS获取元素的位置与尺寸实现方法
2017/10/18 Javascript
vue 项目打包通过命令修改 vue-router 模式 修改 API 接口前缀
2018/06/13 Javascript
layui的table单击行勾选checkbox功能方法
2018/08/14 Javascript
在vue-cli 3中给stylus、sass样式传入共享的全局变量
2019/08/12 Javascript
Python3数据库操作包pymysql的操作方法
2018/07/16 Python
python爬虫简单的添加代理进行访问的实现代码
2019/04/04 Python
python3图片文件批量重命名处理
2019/10/31 Python
Jupyter notebook如何修改平台字体
2020/05/13 Python
keras多显卡训练方式
2020/06/10 Python
python 求两个向量的顺时针夹角操作
2021/03/04 Python
CSS3 分类菜单效果
2019/05/27 HTML / CSS
中国茶叶、茶具一站式网上购物商城:醉品茶城
2018/07/03 全球购物
北京一家公司的.net开发工程师笔试题
2012/04/17 面试题
大学生怎样进行自我评价
2013/12/07 职场文书
学校万圣节活动方案
2014/02/13 职场文书
美术教师求职信范文
2015/03/20 职场文书
教师岗位说明书
2015/09/30 职场文书
2016学习依法治国心得体会
2016/01/15 职场文书
Mysql案例刨析事务隔离级别
2021/09/25 MySQL