tensorflow将图片保存为tfrecord和tfrecord的读取方式


Posted in Python onFebruary 17, 2020

tensorflow官方提供了3种方法来读取数据:

预加载数据(preloaded data):在TensorFlow图中定义常量或变量来保存所有的数据,适用于数据量不太大的情况。填充数据(feeding):通过Python产生数据,然后再把数据填充到后端。

从文件读取数据(reading from file):从文件中直接读取,然后通过队列管理器从文件中读取数据。

本文主要介绍第三种方法,通过tfrecord文件来保存和读取数据,对于前两种读取数据的方式也会进行一个简单的介绍。

项目下载github地址:https://github.com/steelOneself/tensorflow_learn/tree/master/tf_records_writer_read

一、预加载数据

a = tf.constant([1,2,3])
  b = tf.constant([4,5,6])
  c = tf.add(a,b)
  with tf.Session() as sess:
    print(sess.run(c))#[5 7 9]

这种方式加载数据比较简单,它是直接将数据嵌入在数据流图中,当训练数据较大时,比较消耗内存。

二、填充数据

通过先定义placeholder然后再通过feed_dict来喂养数据,这种方式在TensorFlow中使用的也是比较多的,但是也存在数据量大时比较消耗内存的缺点,下面介绍一种更高效的数据读取方式,通过tfrecord文件来读取数据。

x = tf.placeholder(tf.int16)
  y = tf.placeholder(tf.int16)
  z = tf.add(x,y)
  with tf.Session() as sess:
    print(sess.run(z,feed_dict={x:[1,2,3],y:[4,5,6]}))
    #[5 7 9]

三、从文件读取数据

通过slim来实现将图片保存为tfrecord文件和tfrecord文件的读取,slim是基于TensorFlow的一个更高级别的封装模型,通过slim来编程可以实现更高效率和更简洁的代码。

在本次实验中使用的数据集是kaggle的dog vs cat,数据集下载地址:https://www.kaggle.com/c/dogs-vs-cats/data

1、tfrecord文件的保存

a、参数设置

dataset_dir_path:训练集图片存放的上级目录(train下还有一个train目录用来存放图片),在dog vs cat数据集中,dog和cat类的区别是依靠图片的名称,如果你的数据集通过文件夹的名称来划分图片类标的,可能需要对代码进行部分修改。

label_name_to_num:字符串类标与数字类标的对应关系,在将图片保存为tfrecord文件的时候,需要将字符串转为整数类标0和1,方便后的训练。

label_num_to_name:数字类标与字符串类标的对应关系。

val_size:验证集在训练集中所占的比例,训练集一共有25000张图片,用20000张来训练,5000张来进行验证。

batch_size:在读取tfrecord文件的时候,每次读取图片的数量。

#数据所在的目录路径
dataset_dir_path = "D:/dataset/kaggle/cat_or_dog/train"
#类标名称和数字的对应关系
label_name_to_num = {"cat":0,"dog":1}
label_num_to_name = {value:key for key,value in label_name_to_num.items()}
#设置验证集占整个数据集的比例
val_size = 0.2
batch_size = 1

b、获取训练集所有的图片路径

获取训练目录下所有的dog和cat的图片路径,将它们分开保存,便于后面训练集和验证集数据的划分,保证每类图片在所占的比例相同。

#获取文件所在路径
 dataset_dir = os.path.join(dataset_dir,split_name)
 #遍历目录下的所有图片
 for filename in os.listdir(dataset_dir):
   #获取文件的路径
   file_path = os.path.join(dataset_dir,filename)
   if file_path.endswith("jpg") and os.path.exists(file_path):
     #获取类别的名称
     label_name = filename.split(".")[0]
     if label_name == "cat":
       cat_img_paths.append(file_path)
     elif label_name == "dog":
       dog_img_paths.append(file_path)
 return cat_img_paths,dog_img_paths

c、设置需要保存的图片信息

对于训练集的图片主要保存图片的字节数据、图片的格式、图片的标签、图片的高和宽,测试集保存为tfrecord文件的时候需要保存图片的名称,因为在提交数据的时候需要用到图片的名称信息。在保存图片信息的时候,需要先将这些信息转换为byte数据才能写入到tfrecord文件中。

def int64_feature(values):
 if not isinstance(values, (tuple, list)):
  values = [values]
 return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
 
def bytes_feature(values):
 return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
 
 
#将图片信息转换为tfrecords可以保存的序列化信息
def image_to_tfexample(split_name,image_data, image_format, height, width, img_info):
  '''
  :param split_name: train或val或test
  :param image_data: 图片的二进制数据
  :param image_format: 图片的格式
  :param height: 图片的高
  :param width: 图片的宽
  :param img_info: 图片的标签或图片的名称,当split_name为test时,img_info为图片的名称否则为图片标签
  :return:
  '''
  if split_name == "test":
    return tf.train.Example(features=tf.train.Features(feature={
       'image/encoded': bytes_feature(image_data),
       'image/format': bytes_feature(image_format),
       'image/img_name': bytes_feature(img_info),
       'image/height': int64_feature(height),
       'image/width': int64_feature(width),
     }))
  else:
     return tf.train.Example(features=tf.train.Features(feature={
       'image/encoded': bytes_feature(image_data),
       'image/format': bytes_feature(image_format),
       'image/label': int64_feature(img_info),
       'image/height': int64_feature(height),
       'image/width': int64_feature(width),
     }))

d、保存tfrecord文件

主要是通过TFRecordWriter来保存tfrecord文件,在将图片信息保存为tfrecord文件的时候,需要先将图片信息序列化为字符串才能进行写入。ImageReader类可以将图片字节数据解码为指定格式的图片,获取图片的宽和高信息。

_get_dataset_filename函数是通过数据集的名称和split_name的名称来组合获取tfrecord文件的名称,tfrecord名称如下:

tensorflow将图片保存为tfrecord和tfrecord的读取方式

def _convert_tfrecord_dataset(split_name, filenames, label_name_to_id, 
dataset_dir, tfrecord_filename, _NUM_SHARDS):
  '''
  :param split_name:train或val或test
  :param filenames:图片的路径列表
  :param label_name_to_id:标签名与数字标签的对应关系
  :param dataset_dir:数据存放的目录
  :param tfrecord_filename:文件保存的前缀名
  :param _NUM_SHARDS:将整个数据集分为几个文件
  :return:
  '''
  assert split_name in ['train', 'val','test']
  #计算平均每一个tfrecords文件保存多少张图片
  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))
  with tf.Graph().as_default():
    image_reader = ImageReader()
    with tf.Session('') as sess:
      for shard_id in range(_NUM_SHARDS):
        #获取tfrecord文件的名称
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id,
 tfrecord_filename = tfrecord_filename, _NUM_SHARDS = _NUM_SHARDS)
        #写tfrecords文件
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
          for i in range(start_ndx, end_ndx):
            #更新控制台中已经完成的图片数量
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
              i+1, len(filenames), shard_id))
            sys.stdout.flush()
            #读取图片,将图片数据读取为bytes
            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
            #获取图片的高和宽
            height, width = image_reader.read_image_dims(sess, image_data)
            #获取路径中的图片名称
            img_name = os.path.basename(filenames[i])
            if split_name == "test":
              #需要将图片名称转换为二进制
              example = image_to_tfexample(
                split_name,image_data, b'jpg', height, width, img_name.encode())
              tfrecord_writer.write(example.SerializeToString())
            else:
              #获取图片的类别
              class_name = img_name.split(".")[0]
              label_id = label_name_to_id[class_name]
              example = image_to_tfexample(
                split_name,image_data, b'jpg', height, width, label_id)
              tfrecord_writer.write(example.SerializeToString())
        sys.stdout.write('\n')
        sys.stdout.flush()

e、将数据集分为验证集和训练集保存为tfrecord文件

先获取数据集中所有图片的路径和图片的标签信息,将不同类别的图片分为训练集和验证集,并保证训练集和验证集中不同类别的图片数量保持相同,在保存为tfrecord文件之前,打乱所有图片的路径。将训练集分为了2个tfrecord文件,验证集保存为1个tfrecord文件。

#生成tfrecord文件
def generate_tfreocrd():
  #获取目录下所有的猫和狗图片的路径
  cat_img_paths,dog_img_paths = _get_dateset_imgPaths(dataset_dir_path,"train")
  #打乱路径列表的顺序
  np.random.shuffle(cat_img_paths)
  np.random.shuffle(dog_img_paths)
  #计算不同类别验证集所占的图片数量
  cat_val_num = int(len(cat_img_paths) * val_size)
  dog_val_num = int(len(dog_img_paths) * val_size)
  #将所有的图片路径分为训练集和验证集
  train_img_paths = cat_img_paths[cat_val_num:]
  val_img_paths = cat_img_paths[:cat_val_num]
  train_img_paths.extend(dog_img_paths[dog_val_num:])
  val_img_paths.extend(dog_img_paths[:dog_val_num])
  #打乱训练集和验证集的顺序
  np.random.shuffle(train_img_paths)
  np.random.shuffle(val_img_paths)
  #将训练集保存为tfrecord文件
  _convert_tfrecord_dataset("train",train_img_paths,label_name_to_num,dataset_dir_path,"catVSdog",2)
  #将验证集保存为tfrecord文件
  _convert_tfrecord_dataset("val",val_img_paths,label_name_to_num,dataset_dir_path,"catVSdog",1)

通过控制台你能够看到tfrecord文件的保存进度

tensorflow将图片保存为tfrecord和tfrecord的读取方式

2、从tfrecord文件中读取数据

a、读取tfrecord文件,将数据转换为dataset

通过TFRecordReader来读取tfrecord文件,在读取tfrecord文件时需要通过tf.FixedLenFeature来反序列化存储的图片信息,这里我们只读取图片数据和图片的标签,再通过slim模块将图片数据和标签信息存储为一个dataset。

#创建一个tfrecord读文件对象
  reader = tf.TFRecordReader
    keys_to_feature = {
      "image/encoded":tf.FixedLenFeature((),tf.string,default_value=""),
      "image/format":tf.FixedLenFeature((),tf.string,default_value="jpg"),
     "image/label":tf.FixedLenFeature([],tf.int64,default_value=tf.zeros([],tf.int64))
    }
    items_to_handles = {
      "image":slim.tfexample_decoder.Image(),
      "label":slim.tfexample_decoder.Tensor("image/label")
    }
    items_to_descriptions = {
      "image":"a 3-channel RGB image",
      "img_name":"a image label"
    }
    #创建一个tfrecoder解析对象
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_feature,items_to_handles)
    #读取所有的tfrecord文件,创建数据集
    dataset = slim.dataset.Dataset(
      data_sources = tfrecord_paths,
      decoder = decoder,
      reader = reader,
      num_readers = 4,
      num_samples = num_imgs,
      num_classes = num_classes,
      labels_to_name = labels_to_name,
      items_to_descriptions = items_to_descriptions
    )

b、获取batch数据

preprocessing_image对图片进行预处理,对图片进行数据增强,输出后的图片尺寸由height和width参数决定,固定图片的尺寸方便CNN的模型训练。

def load_batch(split_name,dataset,batch_size,height,width):
  data_provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    common_queue_capacity = 24 + 3 * batch_size,
    common_queue_min = 24
  )
    raw_image,img_label = data_provider.get(["image","label"])
    #Perform the correct preprocessing for this image depending if it is training or evaluating
    image = preprocess_image(raw_image, height, width,True)
    #As for the raw images, we just do a simple reshape to batch it up
    raw_image = tf.expand_dims(raw_image, 0)
    raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width])
    raw_image = tf.squeeze(raw_image)
    #获取一个batch数据
    images,raw_image,labels = tf.train.batch(
      [image,raw_image,img_label],
      batch_size=batch_size,
      num_threads=4,
      capacity=4*batch_size,
      allow_smaller_final_batch=True
    )
    return images,raw_image,labels

c、读取tfrecord文件

#读取tfrecord文件
def read_tfrecord():
  #从tfreocrd文件中读取数据
  train_dataset = get_dataset_by_tfrecords("train",dataset_dir_path,"catVSdog",2,label_num_to_name)
  images,raw_images,labels = load_batch("train",train_dataset,batch_size,227,227)
  with tf.Session() as sess:
    threads = tf.train.start_queue_runners(sess)
    for i in range(6):
      train_img,train_label = sess.run([raw_images,labels])
      plt.subplot(2,3,i+1)
      plt.imshow(np.array(train_img[0]))
      plt.title("image label:%s"%str(label_num_to_name[train_label[0]]))
    plt.show()

读取训练集的tfrecord文件,只从tfrecord文件中获取了图片数据和图片的标签,images表示的是预处理后的图片,raw_images表示的是没有经过预处理的图片。

tensorflow将图片保存为tfrecord和tfrecord的读取方式

以上这篇tensorflow将图片保存为tfrecord和tfrecord的读取方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 正则表达式(转义问题)
Dec 15 Python
由Python运算π的值深入Python中科学计算的实现
Apr 17 Python
在Python的web框架中中编写日志列表的教程
Apr 30 Python
分享python数据统计的一些小技巧
Jul 21 Python
Python中运算符"=="和"is"的详解
Oct 08 Python
python Pygame的具体使用讲解
Nov 03 Python
python爬虫之urllib,伪装,超时设置,异常处理的方法
Dec 19 Python
简单了解python中的与或非运算
Sep 18 Python
浅析python redis的连接及相关操作
Nov 07 Python
python 链接sqlserver 写接口实例
Mar 11 Python
关于PySnooper 永远不要使用print进行调试的问题
Mar 04 Python
Python作用域和名称空间的详细介绍
Apr 13 Python
Python 读取有公式cell的结果内容实例方法
Feb 17 #Python
Python破解BiliBili滑块验证码的思路详解(完美避开人机识别)
Feb 17 #Python
Tensorflow 实现将图像与标签数据转化为tfRecord文件
Feb 17 #Python
将自己的数据集制作成TFRecord格式教程
Feb 17 #Python
tensorflow 实现数据类型转换
Feb 17 #Python
Django Haystack 全文检索与关键词高亮的实现
Feb 17 #Python
python使用docx模块读写docx文件的方法与docx模块常用方法详解
Feb 17 #Python
You might like
让PHP开发者事半功倍的十大技巧小结
2010/04/20 PHP
php数组函数序列之asort() - 对数组的元素值进行升序排序,保持索引关系
2011/11/02 PHP
phpmyadmin出现Cannot start session without errors问题解决方法
2014/08/14 PHP
PHP file_get_contents函数读取远程数据超时的解决方法
2015/05/13 PHP
php+jQuery+Ajax实现点赞效果的方法(附源码下载)
2020/07/21 PHP
PHP实现UTF8二进制及明文字符串的转化功能示例
2017/11/20 PHP
javascript firefox不显示本地预览图片问题的解决方法
2008/11/12 Javascript
AJAX 网页保留浏览器前进后退等功能
2011/02/12 Javascript
javascript 闭包
2011/09/15 Javascript
javascript实现英文首字母大写
2015/04/23 Javascript
详解AngularJS1.x学习directive 中‘& ’‘=’ ‘@’符号的区别使用
2017/08/23 Javascript
node.js 发布订阅模式的实例
2017/09/10 Javascript
Node.JS使用Sequelize操作MySQL的示例代码
2017/10/09 Javascript
详解vuex的简单使用
2018/03/12 Javascript
javascript防抖函数debounce详解
2019/06/11 Javascript
基于Vue el-autocomplete 实现类似百度搜索框功能
2019/10/25 Javascript
D3.js 实现带伸缩时间轴拓扑图的示例代码
2020/01/20 Javascript
vue实现导航标题栏随页面滚动渐隐渐显效果
2020/03/12 Javascript
微信小程序间使用navigator跳转传值问题实例分析
2020/03/27 Javascript
vue 导航守卫和axios拦截器有哪些区别
2020/12/19 Vue.js
[00:30]塑造者的传承礼包-戴泽“暗影之焰”套装展示视频
2014/04/04 DOTA
[01:57]2018DOTA2亚洲邀请赛赛前采访-iG
2018/04/03 DOTA
[49:17]DOTA2-DPC中国联赛 正赛 Phoenix vs Dynasty BO3 第三场 1月26日
2021/03/11 DOTA
Python常用的日期时间处理方法示例
2015/02/08 Python
python3实现短网址和数字相互转换的方法
2015/04/28 Python
python opencv实现证件照换底功能
2019/08/19 Python
Python基于百度AI实现OCR文字识别
2020/04/02 Python
如果有两个类A,B,怎么样才能使A在发生一个事件的时候通知B
2016/03/12 面试题
大学生预备党员自我评价分享
2013/11/16 职场文书
园林资料员岗位职责
2013/12/30 职场文书
中学生学雷锋活动心得体会
2014/03/10 职场文书
医学生职业生涯规划书范文
2014/03/13 职场文书
优秀乡村医生先进事迹材料
2014/08/23 职场文书
毕业实习自我鉴定范文2014
2014/09/26 职场文书
战略性融资合作协议书范本
2014/10/17 职场文书
合作合同协议书范本
2015/01/27 职场文书