将自己的数据集制作成TFRecord格式教程


Posted in Python onFebruary 17, 2020

在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入

此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格

式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。

1. 原本的数据集

此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。

将自己的数据集制作成TFRecord格式教程

2.制作成TFRecord格式

tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。

#生成整数型的属性
def _int64_feature(value):
 return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
 
#生成字符串类型的属性
def _bytes_feature(value):
 return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
 
#制作TFRecord格式
def createTFRecord(filename,mapfile):
 class_map = {}
 data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
 classes = {'xiansu60','xiansu100'}
 #输出TFRecord文件的地址
 
 writer = tf.python_io.TFRecordWriter(filename)
 
 for index,name in enumerate(classes):
  class_path=data_dir+name+'/'
  class_map[index] = name
  for img_name in os.listdir(class_path):
   img_path = class_path + img_name #每个图片的地址
   img = Image.open(img_path)
   img= img.resize((224,224))
   img_raw = img.tobytes()   #将图片转化成二进制格式
   example = tf.train.Example(features = tf.train.Features(feature = {
    'label':_int64_feature(index),
    'image_raw': _bytes_feature(img_raw)
   }))
   writer.write(example.SerializeToString())
 writer.close()
 
 txtfile = open(mapfile,'w+')
 for key in class_map.keys():
  txtfile.writelines(str(key)+":"+class_map[key]+"\n")
 txtfile.close()

此段代码,运行完后会产生生成的.tfrecord文件。

3. 读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程

#读取train.tfrecord中的数据
def read_and_decode(filename): 
 #创建一个reader来读取TFRecord文件中的样例
 reader = tf.TFRecordReader()
 #创建一个队列来维护输入文件列表
 filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
 #从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
 _,serialized_example = reader.read(filename_queue)
#  print _,serialized_example
 
 #解析读入的一个样例,如果需要解析多个,可以用parse_example
 features = tf.parse_single_example(
 serialized_example,
 features = {'label':tf.FixedLenFeature([], tf.int64),
    'image_raw': tf.FixedLenFeature([], tf.string),})
 #将字符串解析成图像对应的像素数组
 img = tf.decode_raw(features['image_raw'], tf.uint8)
 img = tf.reshape(img,[224, 224, 3]) #reshape为128*128*3通道图片
 img = tf.image.per_image_standardization(img)
 labels = tf.cast(features['label'], tf.int32)
 return img, labels

4. 将图片几个一打包,形成batch

def createBatch(filename,batchsize):
 images,labels = read_and_decode(filename)
 
 min_after_dequeue = 10
 capacity = min_after_dequeue + 3 * batchsize
 
 image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
              batch_size=batchsize, 
              capacity=capacity, 
              min_after_dequeue=min_after_dequeue
              )
 
 label_batch = tf.one_hot(label_batch,depth=2)
 return image_batch, label_batch

5.主函数

if __name__ =="__main__":
 #训练图片两张为一个batch,进行训练,测试图片一起进行测试
 mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
 train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
#  createTFRecord(train_filename,mapfile)
 test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
#  createTFRecord(test_filename,mapfile)
 image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
 test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
 with tf.Session() as sess:
  initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
  sess.run(initop)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess = sess, coord = coord)
 
  try:
   step = 0
   while 1:
    _image_batch,_label_batch = sess.run([image_batch,label_batch])
    step += 1
    print step
    print (_label_batch)
  except tf.errors.OutOfRangeError:
   print (" trainData done!")
   
  try:
   step = 0
   while 1:
    _test_images,_test_labels = sess.run([test_images,test_labels])
    step += 1
    print step
 #     print _image_batch.shape
    print (_test_labels)
  except tf.errors.OutOfRangeError:
   print (" TEST done!")
  coord.request_stop()
  coord.join(threads)

此时,生成的batch,就可以feed进网络了。

以上这篇将自己的数据集制作成TFRecord格式教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
剖析Python的Tornado框架中session支持的实现代码
Aug 21 Python
OpenCV实现人脸识别
Apr 07 Python
Python实现简单的获取图片爬虫功能示例
Jul 12 Python
Python实现识别手写数字 Python图片读入与处理
Mar 23 Python
Python实现的质因式分解算法示例
May 03 Python
Python3.5面向对象程序设计之类的继承和多态详解
Apr 24 Python
python Django框架实现web端分页呈现数据
Oct 31 Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 Python
浅谈Python中的模块
Jun 10 Python
Pyinstaller加密打包应用的示例代码
Jun 11 Python
python 基于opencv实现图像增强
Dec 23 Python
tensorboard 可视化之localhost:6006不显示的解决方案
May 22 Python
tensorflow 实现数据类型转换
Feb 17 #Python
Django Haystack 全文检索与关键词高亮的实现
Feb 17 #Python
python使用docx模块读写docx文件的方法与docx模块常用方法详解
Feb 17 #Python
python itsdangerous模块的具体使用方法
Feb 17 #Python
django-crontab实现服务端的定时任务的示例代码
Feb 17 #Python
TensorFlow通过文件名/文件夹名获取标签,并加入队列的实现
Feb 17 #Python
Django 项目通过加载不同env文件来区分不同环境
Feb 17 #Python
You might like
php数组函数序列之array_splice() - 在数组任意位置插入元素
2011/11/07 PHP
Yii使用CLinkPager分页实例详解
2014/07/23 PHP
php实现的美国50个州选择列表实例
2015/04/20 PHP
Zend Framework基于Command命令行建立ZF项目的方法
2017/02/18 PHP
JavaScript constructor和instanceof,JSOO中的一对欢喜冤家
2009/05/25 Javascript
JavaScript 内置对象属性及方法集合
2010/07/04 Javascript
关于javascript DOM事件模型的两件事
2010/07/22 Javascript
javascript学习(一)构建自己的JS库
2013/01/02 Javascript
简单常用的幻灯片播放实现代码
2013/09/25 Javascript
js获取当前月的第一天和最后一天的小例子
2013/11/18 Javascript
jquery数组之存放checkbox全选值示例代码
2013/12/20 Javascript
Javascript:为input设置readOnly属性(示例讲解)
2013/12/25 Javascript
javascript得到当前页的来路即前一页地址的方法
2014/02/18 Javascript
jQuery中index()的用法分析
2014/09/05 Javascript
jQuery实现限制textarea文本框输入字符数量的方法
2015/05/28 Javascript
js中substr,substring,indexOf,lastIndexOf,split,replace的用法详解
2015/11/09 Javascript
jquery实现具有嵌套功能的选项卡
2016/02/12 Javascript
在JavaScript中使用JSON数据
2016/02/15 Javascript
Bootstrap模态对话框的简单使用
2016/04/29 Javascript
Bootstrap布局方式详解
2016/05/27 Javascript
jQuery EasyUI封装简化操作
2016/09/18 Javascript
Javascript Event(事件)的传播与冒泡
2017/01/23 Javascript
微信小程序开发实现的IP地址查询功能示例
2019/03/28 Javascript
javascript读取本地文件和目录方法详解
2020/08/06 Javascript
[38:40]2018DOTA2亚洲邀请赛 4.6淘汰赛 mineski vs LGD 第一场
2018/04/10 DOTA
python脚本实现查找webshell的方法
2014/07/31 Python
python实现定时发送qq消息
2019/01/18 Python
python笔记之mean()函数实现求取均值的功能代码
2019/07/05 Python
MONNIER Frères英国官网:源自巴黎女士奢侈品配饰电商平台
2018/12/06 全球购物
Final类有什么特点
2012/04/25 面试题
数控技术专科生自我评价
2014/01/08 职场文书
十佳中学生事迹材料
2014/06/02 职场文书
房产授权委托书范本
2014/09/22 职场文书
房屋登记授权委托书范本
2014/10/09 职场文书
检讨书怎么写
2015/01/23 职场文书
2015年世界急救日宣传活动方案
2015/05/06 职场文书