tensorflow将图片保存为tfrecord和tfrecord的读取方式


Posted in Python onFebruary 17, 2020

tensorflow官方提供了3种方法来读取数据:

预加载数据(preloaded data):在TensorFlow图中定义常量或变量来保存所有的数据,适用于数据量不太大的情况。填充数据(feeding):通过Python产生数据,然后再把数据填充到后端。

从文件读取数据(reading from file):从文件中直接读取,然后通过队列管理器从文件中读取数据。

本文主要介绍第三种方法,通过tfrecord文件来保存和读取数据,对于前两种读取数据的方式也会进行一个简单的介绍。

项目下载github地址:https://github.com/steelOneself/tensorflow_learn/tree/master/tf_records_writer_read

一、预加载数据

a = tf.constant([1,2,3])
  b = tf.constant([4,5,6])
  c = tf.add(a,b)
  with tf.Session() as sess:
    print(sess.run(c))#[5 7 9]

这种方式加载数据比较简单,它是直接将数据嵌入在数据流图中,当训练数据较大时,比较消耗内存。

二、填充数据

通过先定义placeholder然后再通过feed_dict来喂养数据,这种方式在TensorFlow中使用的也是比较多的,但是也存在数据量大时比较消耗内存的缺点,下面介绍一种更高效的数据读取方式,通过tfrecord文件来读取数据。

x = tf.placeholder(tf.int16)
  y = tf.placeholder(tf.int16)
  z = tf.add(x,y)
  with tf.Session() as sess:
    print(sess.run(z,feed_dict={x:[1,2,3],y:[4,5,6]}))
    #[5 7 9]

三、从文件读取数据

通过slim来实现将图片保存为tfrecord文件和tfrecord文件的读取,slim是基于TensorFlow的一个更高级别的封装模型,通过slim来编程可以实现更高效率和更简洁的代码。

在本次实验中使用的数据集是kaggle的dog vs cat,数据集下载地址:https://www.kaggle.com/c/dogs-vs-cats/data

1、tfrecord文件的保存

a、参数设置

dataset_dir_path:训练集图片存放的上级目录(train下还有一个train目录用来存放图片),在dog vs cat数据集中,dog和cat类的区别是依靠图片的名称,如果你的数据集通过文件夹的名称来划分图片类标的,可能需要对代码进行部分修改。

label_name_to_num:字符串类标与数字类标的对应关系,在将图片保存为tfrecord文件的时候,需要将字符串转为整数类标0和1,方便后的训练。

label_num_to_name:数字类标与字符串类标的对应关系。

val_size:验证集在训练集中所占的比例,训练集一共有25000张图片,用20000张来训练,5000张来进行验证。

batch_size:在读取tfrecord文件的时候,每次读取图片的数量。

#数据所在的目录路径
dataset_dir_path = "D:/dataset/kaggle/cat_or_dog/train"
#类标名称和数字的对应关系
label_name_to_num = {"cat":0,"dog":1}
label_num_to_name = {value:key for key,value in label_name_to_num.items()}
#设置验证集占整个数据集的比例
val_size = 0.2
batch_size = 1

b、获取训练集所有的图片路径

获取训练目录下所有的dog和cat的图片路径,将它们分开保存,便于后面训练集和验证集数据的划分,保证每类图片在所占的比例相同。

#获取文件所在路径
 dataset_dir = os.path.join(dataset_dir,split_name)
 #遍历目录下的所有图片
 for filename in os.listdir(dataset_dir):
   #获取文件的路径
   file_path = os.path.join(dataset_dir,filename)
   if file_path.endswith("jpg") and os.path.exists(file_path):
     #获取类别的名称
     label_name = filename.split(".")[0]
     if label_name == "cat":
       cat_img_paths.append(file_path)
     elif label_name == "dog":
       dog_img_paths.append(file_path)
 return cat_img_paths,dog_img_paths

c、设置需要保存的图片信息

对于训练集的图片主要保存图片的字节数据、图片的格式、图片的标签、图片的高和宽,测试集保存为tfrecord文件的时候需要保存图片的名称,因为在提交数据的时候需要用到图片的名称信息。在保存图片信息的时候,需要先将这些信息转换为byte数据才能写入到tfrecord文件中。

def int64_feature(values):
 if not isinstance(values, (tuple, list)):
  values = [values]
 return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
 
def bytes_feature(values):
 return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
 
 
#将图片信息转换为tfrecords可以保存的序列化信息
def image_to_tfexample(split_name,image_data, image_format, height, width, img_info):
  '''
  :param split_name: train或val或test
  :param image_data: 图片的二进制数据
  :param image_format: 图片的格式
  :param height: 图片的高
  :param width: 图片的宽
  :param img_info: 图片的标签或图片的名称,当split_name为test时,img_info为图片的名称否则为图片标签
  :return:
  '''
  if split_name == "test":
    return tf.train.Example(features=tf.train.Features(feature={
       'image/encoded': bytes_feature(image_data),
       'image/format': bytes_feature(image_format),
       'image/img_name': bytes_feature(img_info),
       'image/height': int64_feature(height),
       'image/width': int64_feature(width),
     }))
  else:
     return tf.train.Example(features=tf.train.Features(feature={
       'image/encoded': bytes_feature(image_data),
       'image/format': bytes_feature(image_format),
       'image/label': int64_feature(img_info),
       'image/height': int64_feature(height),
       'image/width': int64_feature(width),
     }))

d、保存tfrecord文件

主要是通过TFRecordWriter来保存tfrecord文件,在将图片信息保存为tfrecord文件的时候,需要先将图片信息序列化为字符串才能进行写入。ImageReader类可以将图片字节数据解码为指定格式的图片,获取图片的宽和高信息。

_get_dataset_filename函数是通过数据集的名称和split_name的名称来组合获取tfrecord文件的名称,tfrecord名称如下:

tensorflow将图片保存为tfrecord和tfrecord的读取方式

def _convert_tfrecord_dataset(split_name, filenames, label_name_to_id, 
dataset_dir, tfrecord_filename, _NUM_SHARDS):
  '''
  :param split_name:train或val或test
  :param filenames:图片的路径列表
  :param label_name_to_id:标签名与数字标签的对应关系
  :param dataset_dir:数据存放的目录
  :param tfrecord_filename:文件保存的前缀名
  :param _NUM_SHARDS:将整个数据集分为几个文件
  :return:
  '''
  assert split_name in ['train', 'val','test']
  #计算平均每一个tfrecords文件保存多少张图片
  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))
  with tf.Graph().as_default():
    image_reader = ImageReader()
    with tf.Session('') as sess:
      for shard_id in range(_NUM_SHARDS):
        #获取tfrecord文件的名称
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id,
 tfrecord_filename = tfrecord_filename, _NUM_SHARDS = _NUM_SHARDS)
        #写tfrecords文件
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
          for i in range(start_ndx, end_ndx):
            #更新控制台中已经完成的图片数量
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
              i+1, len(filenames), shard_id))
            sys.stdout.flush()
            #读取图片,将图片数据读取为bytes
            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
            #获取图片的高和宽
            height, width = image_reader.read_image_dims(sess, image_data)
            #获取路径中的图片名称
            img_name = os.path.basename(filenames[i])
            if split_name == "test":
              #需要将图片名称转换为二进制
              example = image_to_tfexample(
                split_name,image_data, b'jpg', height, width, img_name.encode())
              tfrecord_writer.write(example.SerializeToString())
            else:
              #获取图片的类别
              class_name = img_name.split(".")[0]
              label_id = label_name_to_id[class_name]
              example = image_to_tfexample(
                split_name,image_data, b'jpg', height, width, label_id)
              tfrecord_writer.write(example.SerializeToString())
        sys.stdout.write('\n')
        sys.stdout.flush()

e、将数据集分为验证集和训练集保存为tfrecord文件

先获取数据集中所有图片的路径和图片的标签信息,将不同类别的图片分为训练集和验证集,并保证训练集和验证集中不同类别的图片数量保持相同,在保存为tfrecord文件之前,打乱所有图片的路径。将训练集分为了2个tfrecord文件,验证集保存为1个tfrecord文件。

#生成tfrecord文件
def generate_tfreocrd():
  #获取目录下所有的猫和狗图片的路径
  cat_img_paths,dog_img_paths = _get_dateset_imgPaths(dataset_dir_path,"train")
  #打乱路径列表的顺序
  np.random.shuffle(cat_img_paths)
  np.random.shuffle(dog_img_paths)
  #计算不同类别验证集所占的图片数量
  cat_val_num = int(len(cat_img_paths) * val_size)
  dog_val_num = int(len(dog_img_paths) * val_size)
  #将所有的图片路径分为训练集和验证集
  train_img_paths = cat_img_paths[cat_val_num:]
  val_img_paths = cat_img_paths[:cat_val_num]
  train_img_paths.extend(dog_img_paths[dog_val_num:])
  val_img_paths.extend(dog_img_paths[:dog_val_num])
  #打乱训练集和验证集的顺序
  np.random.shuffle(train_img_paths)
  np.random.shuffle(val_img_paths)
  #将训练集保存为tfrecord文件
  _convert_tfrecord_dataset("train",train_img_paths,label_name_to_num,dataset_dir_path,"catVSdog",2)
  #将验证集保存为tfrecord文件
  _convert_tfrecord_dataset("val",val_img_paths,label_name_to_num,dataset_dir_path,"catVSdog",1)

通过控制台你能够看到tfrecord文件的保存进度

tensorflow将图片保存为tfrecord和tfrecord的读取方式

2、从tfrecord文件中读取数据

a、读取tfrecord文件,将数据转换为dataset

通过TFRecordReader来读取tfrecord文件,在读取tfrecord文件时需要通过tf.FixedLenFeature来反序列化存储的图片信息,这里我们只读取图片数据和图片的标签,再通过slim模块将图片数据和标签信息存储为一个dataset。

#创建一个tfrecord读文件对象
  reader = tf.TFRecordReader
    keys_to_feature = {
      "image/encoded":tf.FixedLenFeature((),tf.string,default_value=""),
      "image/format":tf.FixedLenFeature((),tf.string,default_value="jpg"),
     "image/label":tf.FixedLenFeature([],tf.int64,default_value=tf.zeros([],tf.int64))
    }
    items_to_handles = {
      "image":slim.tfexample_decoder.Image(),
      "label":slim.tfexample_decoder.Tensor("image/label")
    }
    items_to_descriptions = {
      "image":"a 3-channel RGB image",
      "img_name":"a image label"
    }
    #创建一个tfrecoder解析对象
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_feature,items_to_handles)
    #读取所有的tfrecord文件,创建数据集
    dataset = slim.dataset.Dataset(
      data_sources = tfrecord_paths,
      decoder = decoder,
      reader = reader,
      num_readers = 4,
      num_samples = num_imgs,
      num_classes = num_classes,
      labels_to_name = labels_to_name,
      items_to_descriptions = items_to_descriptions
    )

b、获取batch数据

preprocessing_image对图片进行预处理,对图片进行数据增强,输出后的图片尺寸由height和width参数决定,固定图片的尺寸方便CNN的模型训练。

def load_batch(split_name,dataset,batch_size,height,width):
  data_provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    common_queue_capacity = 24 + 3 * batch_size,
    common_queue_min = 24
  )
    raw_image,img_label = data_provider.get(["image","label"])
    #Perform the correct preprocessing for this image depending if it is training or evaluating
    image = preprocess_image(raw_image, height, width,True)
    #As for the raw images, we just do a simple reshape to batch it up
    raw_image = tf.expand_dims(raw_image, 0)
    raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width])
    raw_image = tf.squeeze(raw_image)
    #获取一个batch数据
    images,raw_image,labels = tf.train.batch(
      [image,raw_image,img_label],
      batch_size=batch_size,
      num_threads=4,
      capacity=4*batch_size,
      allow_smaller_final_batch=True
    )
    return images,raw_image,labels

c、读取tfrecord文件

#读取tfrecord文件
def read_tfrecord():
  #从tfreocrd文件中读取数据
  train_dataset = get_dataset_by_tfrecords("train",dataset_dir_path,"catVSdog",2,label_num_to_name)
  images,raw_images,labels = load_batch("train",train_dataset,batch_size,227,227)
  with tf.Session() as sess:
    threads = tf.train.start_queue_runners(sess)
    for i in range(6):
      train_img,train_label = sess.run([raw_images,labels])
      plt.subplot(2,3,i+1)
      plt.imshow(np.array(train_img[0]))
      plt.title("image label:%s"%str(label_num_to_name[train_label[0]]))
    plt.show()

读取训练集的tfrecord文件,只从tfrecord文件中获取了图片数据和图片的标签,images表示的是预处理后的图片,raw_images表示的是没有经过预处理的图片。

tensorflow将图片保存为tfrecord和tfrecord的读取方式

以上这篇tensorflow将图片保存为tfrecord和tfrecord的读取方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python科学计算之Pandas详解
Jan 15 Python
详谈Python3 操作系统与路径 模块(os / os.path / pathlib)
Apr 26 Python
Python将一个CSV文件里的数据追加到另一个CSV文件的方法
Jul 04 Python
通过python的matplotlib包将Tensorflow数据进行可视化的方法
Jan 09 Python
python协程gevent案例 爬取斗鱼图片过程解析
Aug 27 Python
Pycharm使用远程linux服务器conda/python环境在本地运行的方法(图解))
Dec 09 Python
Python实现自动访问网页的例子
Feb 21 Python
去除python中的字符串空格的简单方法
Dec 22 Python
Django2.1.7 查询数据返回json格式的实现
Dec 29 Python
理解深度学习之深度学习简介
Apr 14 Python
Python爬虫网络请求之代理服务器和动态Cookies
Apr 12 Python
Pytorch中expand()的使用(扩展某个维度)
Jul 15 Python
Python 读取有公式cell的结果内容实例方法
Feb 17 #Python
Python破解BiliBili滑块验证码的思路详解(完美避开人机识别)
Feb 17 #Python
Tensorflow 实现将图像与标签数据转化为tfRecord文件
Feb 17 #Python
将自己的数据集制作成TFRecord格式教程
Feb 17 #Python
tensorflow 实现数据类型转换
Feb 17 #Python
Django Haystack 全文检索与关键词高亮的实现
Feb 17 #Python
python使用docx模块读写docx文件的方法与docx模块常用方法详解
Feb 17 #Python
You might like
使用bcompiler对PHP文件进行加密的代码
2010/08/29 PHP
淘宝ip地址查询类分享(利用淘宝ip库)
2014/01/07 PHP
destoon整合UCenter图文教程
2014/06/21 PHP
Code:findPosX 和 findPosY
2006/12/20 Javascript
JQUERY设置IFRAME的SRC值的代码
2010/11/30 Javascript
EXTjs4.0的store的findRecord的BUG演示代码
2013/06/08 Javascript
node.js解决获取图片真实文件类型的问题
2014/12/20 Javascript
JS实现向表格行添加新单元格的方法
2015/03/30 Javascript
基于JQuery的$.ajax方法进行异步请求导致页面闪烁的解决办法
2016/05/10 Javascript
JS组件Bootstrap Select2使用方法解析
2016/05/30 Javascript
Javascript中获取浏览器类型和操作系统版本等客户端信息常用代码
2016/06/28 Javascript
原生js获取元素样式的简单方法
2016/08/06 Javascript
canvas 画布在主流浏览器中的尺寸限制详细介绍
2016/12/15 Javascript
javascript如何用递归写一个简单的树形结构示例
2017/09/06 Javascript
JS 中使用Promise 实现红绿灯实例代码(demo)
2017/10/20 Javascript
jquery动态添加以及遍历option并获取特定样式名称的option方法
2018/01/29 jQuery
vue项目中引入noVNC远程桌面的方法
2018/03/05 Javascript
Node.js中package.json中库的版本号(~和^)
2019/04/02 Javascript
jQuery实现点击滚动到指定元素上的方法分析
2020/03/19 jQuery
[02:28]DOTA2英雄基础教程 狼人
2013/12/23 DOTA
Python Paramiko模块的安装与使用详解
2016/11/18 Python
Python定义一个跨越多行的字符串的多种方法小结
2018/07/19 Python
用python实现k近邻算法的示例代码
2018/09/06 Python
浅谈Python采集网页时正则表达式匹配换行符的问题
2018/12/20 Python
Python 读取用户指令和格式化打印实现解析
2019/09/02 Python
Python文件夹批处理操作代码实例
2020/07/21 Python
python如何实现递归转非递归
2021/02/25 Python
如何用Python进行时间序列分解和预测
2021/03/01 Python
英国高级百货公司:Harvey Nichols
2017/01/29 全球购物
男方父母婚礼答谢词
2014/01/25 职场文书
《桂林山水》教学反思
2014/02/08 职场文书
小学生常见病防治方案
2014/06/06 职场文书
建筑院校毕业生求职信
2014/06/13 职场文书
驾驶员安全责任书
2014/07/22 职场文书
社区志愿者活动方案
2014/08/18 职场文书
MySQL into_Mysql中replace与replace into用法案例详解
2021/09/14 MySQL