tensorflow将图片保存为tfrecord和tfrecord的读取方式


Posted in Python onFebruary 17, 2020

tensorflow官方提供了3种方法来读取数据:

预加载数据(preloaded data):在TensorFlow图中定义常量或变量来保存所有的数据,适用于数据量不太大的情况。填充数据(feeding):通过Python产生数据,然后再把数据填充到后端。

从文件读取数据(reading from file):从文件中直接读取,然后通过队列管理器从文件中读取数据。

本文主要介绍第三种方法,通过tfrecord文件来保存和读取数据,对于前两种读取数据的方式也会进行一个简单的介绍。

项目下载github地址:https://github.com/steelOneself/tensorflow_learn/tree/master/tf_records_writer_read

一、预加载数据

a = tf.constant([1,2,3])
  b = tf.constant([4,5,6])
  c = tf.add(a,b)
  with tf.Session() as sess:
    print(sess.run(c))#[5 7 9]

这种方式加载数据比较简单,它是直接将数据嵌入在数据流图中,当训练数据较大时,比较消耗内存。

二、填充数据

通过先定义placeholder然后再通过feed_dict来喂养数据,这种方式在TensorFlow中使用的也是比较多的,但是也存在数据量大时比较消耗内存的缺点,下面介绍一种更高效的数据读取方式,通过tfrecord文件来读取数据。

x = tf.placeholder(tf.int16)
  y = tf.placeholder(tf.int16)
  z = tf.add(x,y)
  with tf.Session() as sess:
    print(sess.run(z,feed_dict={x:[1,2,3],y:[4,5,6]}))
    #[5 7 9]

三、从文件读取数据

通过slim来实现将图片保存为tfrecord文件和tfrecord文件的读取,slim是基于TensorFlow的一个更高级别的封装模型,通过slim来编程可以实现更高效率和更简洁的代码。

在本次实验中使用的数据集是kaggle的dog vs cat,数据集下载地址:https://www.kaggle.com/c/dogs-vs-cats/data

1、tfrecord文件的保存

a、参数设置

dataset_dir_path:训练集图片存放的上级目录(train下还有一个train目录用来存放图片),在dog vs cat数据集中,dog和cat类的区别是依靠图片的名称,如果你的数据集通过文件夹的名称来划分图片类标的,可能需要对代码进行部分修改。

label_name_to_num:字符串类标与数字类标的对应关系,在将图片保存为tfrecord文件的时候,需要将字符串转为整数类标0和1,方便后的训练。

label_num_to_name:数字类标与字符串类标的对应关系。

val_size:验证集在训练集中所占的比例,训练集一共有25000张图片,用20000张来训练,5000张来进行验证。

batch_size:在读取tfrecord文件的时候,每次读取图片的数量。

#数据所在的目录路径
dataset_dir_path = "D:/dataset/kaggle/cat_or_dog/train"
#类标名称和数字的对应关系
label_name_to_num = {"cat":0,"dog":1}
label_num_to_name = {value:key for key,value in label_name_to_num.items()}
#设置验证集占整个数据集的比例
val_size = 0.2
batch_size = 1

b、获取训练集所有的图片路径

获取训练目录下所有的dog和cat的图片路径,将它们分开保存,便于后面训练集和验证集数据的划分,保证每类图片在所占的比例相同。

#获取文件所在路径
 dataset_dir = os.path.join(dataset_dir,split_name)
 #遍历目录下的所有图片
 for filename in os.listdir(dataset_dir):
   #获取文件的路径
   file_path = os.path.join(dataset_dir,filename)
   if file_path.endswith("jpg") and os.path.exists(file_path):
     #获取类别的名称
     label_name = filename.split(".")[0]
     if label_name == "cat":
       cat_img_paths.append(file_path)
     elif label_name == "dog":
       dog_img_paths.append(file_path)
 return cat_img_paths,dog_img_paths

c、设置需要保存的图片信息

对于训练集的图片主要保存图片的字节数据、图片的格式、图片的标签、图片的高和宽,测试集保存为tfrecord文件的时候需要保存图片的名称,因为在提交数据的时候需要用到图片的名称信息。在保存图片信息的时候,需要先将这些信息转换为byte数据才能写入到tfrecord文件中。

def int64_feature(values):
 if not isinstance(values, (tuple, list)):
  values = [values]
 return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
 
def bytes_feature(values):
 return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
 
 
#将图片信息转换为tfrecords可以保存的序列化信息
def image_to_tfexample(split_name,image_data, image_format, height, width, img_info):
  '''
  :param split_name: train或val或test
  :param image_data: 图片的二进制数据
  :param image_format: 图片的格式
  :param height: 图片的高
  :param width: 图片的宽
  :param img_info: 图片的标签或图片的名称,当split_name为test时,img_info为图片的名称否则为图片标签
  :return:
  '''
  if split_name == "test":
    return tf.train.Example(features=tf.train.Features(feature={
       'image/encoded': bytes_feature(image_data),
       'image/format': bytes_feature(image_format),
       'image/img_name': bytes_feature(img_info),
       'image/height': int64_feature(height),
       'image/width': int64_feature(width),
     }))
  else:
     return tf.train.Example(features=tf.train.Features(feature={
       'image/encoded': bytes_feature(image_data),
       'image/format': bytes_feature(image_format),
       'image/label': int64_feature(img_info),
       'image/height': int64_feature(height),
       'image/width': int64_feature(width),
     }))

d、保存tfrecord文件

主要是通过TFRecordWriter来保存tfrecord文件,在将图片信息保存为tfrecord文件的时候,需要先将图片信息序列化为字符串才能进行写入。ImageReader类可以将图片字节数据解码为指定格式的图片,获取图片的宽和高信息。

_get_dataset_filename函数是通过数据集的名称和split_name的名称来组合获取tfrecord文件的名称,tfrecord名称如下:

tensorflow将图片保存为tfrecord和tfrecord的读取方式

def _convert_tfrecord_dataset(split_name, filenames, label_name_to_id, 
dataset_dir, tfrecord_filename, _NUM_SHARDS):
  '''
  :param split_name:train或val或test
  :param filenames:图片的路径列表
  :param label_name_to_id:标签名与数字标签的对应关系
  :param dataset_dir:数据存放的目录
  :param tfrecord_filename:文件保存的前缀名
  :param _NUM_SHARDS:将整个数据集分为几个文件
  :return:
  '''
  assert split_name in ['train', 'val','test']
  #计算平均每一个tfrecords文件保存多少张图片
  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))
  with tf.Graph().as_default():
    image_reader = ImageReader()
    with tf.Session('') as sess:
      for shard_id in range(_NUM_SHARDS):
        #获取tfrecord文件的名称
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id,
 tfrecord_filename = tfrecord_filename, _NUM_SHARDS = _NUM_SHARDS)
        #写tfrecords文件
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
          for i in range(start_ndx, end_ndx):
            #更新控制台中已经完成的图片数量
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
              i+1, len(filenames), shard_id))
            sys.stdout.flush()
            #读取图片,将图片数据读取为bytes
            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
            #获取图片的高和宽
            height, width = image_reader.read_image_dims(sess, image_data)
            #获取路径中的图片名称
            img_name = os.path.basename(filenames[i])
            if split_name == "test":
              #需要将图片名称转换为二进制
              example = image_to_tfexample(
                split_name,image_data, b'jpg', height, width, img_name.encode())
              tfrecord_writer.write(example.SerializeToString())
            else:
              #获取图片的类别
              class_name = img_name.split(".")[0]
              label_id = label_name_to_id[class_name]
              example = image_to_tfexample(
                split_name,image_data, b'jpg', height, width, label_id)
              tfrecord_writer.write(example.SerializeToString())
        sys.stdout.write('\n')
        sys.stdout.flush()

e、将数据集分为验证集和训练集保存为tfrecord文件

先获取数据集中所有图片的路径和图片的标签信息,将不同类别的图片分为训练集和验证集,并保证训练集和验证集中不同类别的图片数量保持相同,在保存为tfrecord文件之前,打乱所有图片的路径。将训练集分为了2个tfrecord文件,验证集保存为1个tfrecord文件。

#生成tfrecord文件
def generate_tfreocrd():
  #获取目录下所有的猫和狗图片的路径
  cat_img_paths,dog_img_paths = _get_dateset_imgPaths(dataset_dir_path,"train")
  #打乱路径列表的顺序
  np.random.shuffle(cat_img_paths)
  np.random.shuffle(dog_img_paths)
  #计算不同类别验证集所占的图片数量
  cat_val_num = int(len(cat_img_paths) * val_size)
  dog_val_num = int(len(dog_img_paths) * val_size)
  #将所有的图片路径分为训练集和验证集
  train_img_paths = cat_img_paths[cat_val_num:]
  val_img_paths = cat_img_paths[:cat_val_num]
  train_img_paths.extend(dog_img_paths[dog_val_num:])
  val_img_paths.extend(dog_img_paths[:dog_val_num])
  #打乱训练集和验证集的顺序
  np.random.shuffle(train_img_paths)
  np.random.shuffle(val_img_paths)
  #将训练集保存为tfrecord文件
  _convert_tfrecord_dataset("train",train_img_paths,label_name_to_num,dataset_dir_path,"catVSdog",2)
  #将验证集保存为tfrecord文件
  _convert_tfrecord_dataset("val",val_img_paths,label_name_to_num,dataset_dir_path,"catVSdog",1)

通过控制台你能够看到tfrecord文件的保存进度

tensorflow将图片保存为tfrecord和tfrecord的读取方式

2、从tfrecord文件中读取数据

a、读取tfrecord文件,将数据转换为dataset

通过TFRecordReader来读取tfrecord文件,在读取tfrecord文件时需要通过tf.FixedLenFeature来反序列化存储的图片信息,这里我们只读取图片数据和图片的标签,再通过slim模块将图片数据和标签信息存储为一个dataset。

#创建一个tfrecord读文件对象
  reader = tf.TFRecordReader
    keys_to_feature = {
      "image/encoded":tf.FixedLenFeature((),tf.string,default_value=""),
      "image/format":tf.FixedLenFeature((),tf.string,default_value="jpg"),
     "image/label":tf.FixedLenFeature([],tf.int64,default_value=tf.zeros([],tf.int64))
    }
    items_to_handles = {
      "image":slim.tfexample_decoder.Image(),
      "label":slim.tfexample_decoder.Tensor("image/label")
    }
    items_to_descriptions = {
      "image":"a 3-channel RGB image",
      "img_name":"a image label"
    }
    #创建一个tfrecoder解析对象
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_feature,items_to_handles)
    #读取所有的tfrecord文件,创建数据集
    dataset = slim.dataset.Dataset(
      data_sources = tfrecord_paths,
      decoder = decoder,
      reader = reader,
      num_readers = 4,
      num_samples = num_imgs,
      num_classes = num_classes,
      labels_to_name = labels_to_name,
      items_to_descriptions = items_to_descriptions
    )

b、获取batch数据

preprocessing_image对图片进行预处理,对图片进行数据增强,输出后的图片尺寸由height和width参数决定,固定图片的尺寸方便CNN的模型训练。

def load_batch(split_name,dataset,batch_size,height,width):
  data_provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    common_queue_capacity = 24 + 3 * batch_size,
    common_queue_min = 24
  )
    raw_image,img_label = data_provider.get(["image","label"])
    #Perform the correct preprocessing for this image depending if it is training or evaluating
    image = preprocess_image(raw_image, height, width,True)
    #As for the raw images, we just do a simple reshape to batch it up
    raw_image = tf.expand_dims(raw_image, 0)
    raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width])
    raw_image = tf.squeeze(raw_image)
    #获取一个batch数据
    images,raw_image,labels = tf.train.batch(
      [image,raw_image,img_label],
      batch_size=batch_size,
      num_threads=4,
      capacity=4*batch_size,
      allow_smaller_final_batch=True
    )
    return images,raw_image,labels

c、读取tfrecord文件

#读取tfrecord文件
def read_tfrecord():
  #从tfreocrd文件中读取数据
  train_dataset = get_dataset_by_tfrecords("train",dataset_dir_path,"catVSdog",2,label_num_to_name)
  images,raw_images,labels = load_batch("train",train_dataset,batch_size,227,227)
  with tf.Session() as sess:
    threads = tf.train.start_queue_runners(sess)
    for i in range(6):
      train_img,train_label = sess.run([raw_images,labels])
      plt.subplot(2,3,i+1)
      plt.imshow(np.array(train_img[0]))
      plt.title("image label:%s"%str(label_num_to_name[train_label[0]]))
    plt.show()

读取训练集的tfrecord文件,只从tfrecord文件中获取了图片数据和图片的标签,images表示的是预处理后的图片,raw_images表示的是没有经过预处理的图片。

tensorflow将图片保存为tfrecord和tfrecord的读取方式

以上这篇tensorflow将图片保存为tfrecord和tfrecord的读取方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python基础之函数用法实例详解
Sep 10 Python
浅析Python中的join()方法的使用
May 19 Python
Python使用QQ邮箱发送Email的方法实例
Feb 09 Python
浅析Python中return和finally共同挖的坑
Aug 18 Python
python 读取目录下csv文件并绘制曲线v111的方法
Jul 06 Python
Tensorflow加载预训练模型和保存模型的实例
Jul 27 Python
详解将Django部署到Centos7全攻略
Sep 26 Python
Python3模拟curl发送post请求操作示例
May 03 Python
通过PYTHON来实现图像分割详解
Jun 26 Python
FFrpc python客户端lib使用解析
Aug 24 Python
Python3离线安装Requests模块问题
Oct 13 Python
Python getsizeof()和getsize()区分详解
Nov 20 Python
Python 读取有公式cell的结果内容实例方法
Feb 17 #Python
Python破解BiliBili滑块验证码的思路详解(完美避开人机识别)
Feb 17 #Python
Tensorflow 实现将图像与标签数据转化为tfRecord文件
Feb 17 #Python
将自己的数据集制作成TFRecord格式教程
Feb 17 #Python
tensorflow 实现数据类型转换
Feb 17 #Python
Django Haystack 全文检索与关键词高亮的实现
Feb 17 #Python
python使用docx模块读写docx文件的方法与docx模块常用方法详解
Feb 17 #Python
You might like
基于php 随机数的深入理解
2013/06/05 PHP
php导出excel格式数据问题
2014/03/11 PHP
PHP中$GLOBALS与global的区别详解
2019/03/21 PHP
一些常用的Javascript函数
2006/12/22 Javascript
javascript读取RSS数据
2007/01/20 Javascript
高性能Javascript笔记 数据的存储与访问性能优化
2012/08/02 Javascript
javascript实现禁止复制网页内容
2014/12/16 Javascript
js判断浏览器版本以及浏览器内核的方法
2015/01/20 Javascript
javascript递归回溯法解八皇后问题
2015/04/22 Javascript
浅谈JavaScript 的执行顺序
2015/08/07 Javascript
Ionic3 UI组件之autocomplete详解
2017/06/08 Javascript
JavaScript函数绑定用法实例分析
2017/11/14 Javascript
基于vue cli重构多页面脚手架过程详解
2018/01/23 Javascript
详解vue项目中使用token的身份验证的简单实践
2019/03/08 Javascript
细说webpack6 Babel的使用详解
2019/09/26 Javascript
在Vue中使用Select选择器拼接label的操作
2020/10/22 Javascript
PyMongo安装使用笔记
2015/04/27 Python
Python入门之三角函数tan()函数实例详解
2017/11/08 Python
ubuntu安装mysql pycharm sublime
2018/02/20 Python
python按行读取文件,去掉每行的换行符\n的实例
2018/04/19 Python
python 运用Django 开发后台接口的实例
2018/12/11 Python
Python学习笔记之自定义函数用法详解
2019/06/08 Python
python pandas生成时间列表
2019/06/29 Python
python实现多进程按序号批量修改文件名的方法示例
2019/12/30 Python
利用Python脚本批量生成SQL语句
2020/03/04 Python
用Python自动清理电脑内重复文件,只要10行代码(自动脚本)
2021/01/09 Python
Lookfantastic日本官网:英国知名护肤、化妆品和头发护理购物网站
2018/04/21 全球购物
英国专业美容产品在线:Mylee(从指甲到脱毛)
2020/07/06 全球购物
模具设计与制造专业应届生求职信
2013/10/18 职场文书
生产车间主任的个人自我鉴定
2013/10/25 职场文书
思想道德自我评价2015
2015/03/09 职场文书
人生遥控器观后感
2015/06/11 职场文书
导游词之阳朔遇龙河
2019/12/16 职场文书
你知道哪几种MYSQL的连接查询
2021/06/03 MySQL
css样式important规则的正确使用方式
2022/06/10 HTML / CSS
Java多线程并发FutureTask使用详解
2022/06/28 Java/Android