将自己的数据集制作成TFRecord格式教程


Posted in Python onFebruary 17, 2020

在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入

此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格

式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。

1. 原本的数据集

此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。

将自己的数据集制作成TFRecord格式教程

2.制作成TFRecord格式

tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。

#生成整数型的属性
def _int64_feature(value):
 return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
 
#生成字符串类型的属性
def _bytes_feature(value):
 return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
 
#制作TFRecord格式
def createTFRecord(filename,mapfile):
 class_map = {}
 data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
 classes = {'xiansu60','xiansu100'}
 #输出TFRecord文件的地址
 
 writer = tf.python_io.TFRecordWriter(filename)
 
 for index,name in enumerate(classes):
  class_path=data_dir+name+'/'
  class_map[index] = name
  for img_name in os.listdir(class_path):
   img_path = class_path + img_name #每个图片的地址
   img = Image.open(img_path)
   img= img.resize((224,224))
   img_raw = img.tobytes()   #将图片转化成二进制格式
   example = tf.train.Example(features = tf.train.Features(feature = {
    'label':_int64_feature(index),
    'image_raw': _bytes_feature(img_raw)
   }))
   writer.write(example.SerializeToString())
 writer.close()
 
 txtfile = open(mapfile,'w+')
 for key in class_map.keys():
  txtfile.writelines(str(key)+":"+class_map[key]+"\n")
 txtfile.close()

此段代码,运行完后会产生生成的.tfrecord文件。

3. 读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程

#读取train.tfrecord中的数据
def read_and_decode(filename): 
 #创建一个reader来读取TFRecord文件中的样例
 reader = tf.TFRecordReader()
 #创建一个队列来维护输入文件列表
 filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
 #从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
 _,serialized_example = reader.read(filename_queue)
#  print _,serialized_example
 
 #解析读入的一个样例,如果需要解析多个,可以用parse_example
 features = tf.parse_single_example(
 serialized_example,
 features = {'label':tf.FixedLenFeature([], tf.int64),
    'image_raw': tf.FixedLenFeature([], tf.string),})
 #将字符串解析成图像对应的像素数组
 img = tf.decode_raw(features['image_raw'], tf.uint8)
 img = tf.reshape(img,[224, 224, 3]) #reshape为128*128*3通道图片
 img = tf.image.per_image_standardization(img)
 labels = tf.cast(features['label'], tf.int32)
 return img, labels

4. 将图片几个一打包,形成batch

def createBatch(filename,batchsize):
 images,labels = read_and_decode(filename)
 
 min_after_dequeue = 10
 capacity = min_after_dequeue + 3 * batchsize
 
 image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
              batch_size=batchsize, 
              capacity=capacity, 
              min_after_dequeue=min_after_dequeue
              )
 
 label_batch = tf.one_hot(label_batch,depth=2)
 return image_batch, label_batch

5.主函数

if __name__ =="__main__":
 #训练图片两张为一个batch,进行训练,测试图片一起进行测试
 mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
 train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
#  createTFRecord(train_filename,mapfile)
 test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
#  createTFRecord(test_filename,mapfile)
 image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
 test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
 with tf.Session() as sess:
  initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
  sess.run(initop)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess = sess, coord = coord)
 
  try:
   step = 0
   while 1:
    _image_batch,_label_batch = sess.run([image_batch,label_batch])
    step += 1
    print step
    print (_label_batch)
  except tf.errors.OutOfRangeError:
   print (" trainData done!")
   
  try:
   step = 0
   while 1:
    _test_images,_test_labels = sess.run([test_images,test_labels])
    step += 1
    print step
 #     print _image_batch.shape
    print (_test_labels)
  except tf.errors.OutOfRangeError:
   print (" TEST done!")
  coord.request_stop()
  coord.join(threads)

此时,生成的batch,就可以feed进网络了。

以上这篇将自己的数据集制作成TFRecord格式教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中map和列表推导效率比较实例分析
Jun 17 Python
Python urls.py的三种配置写法实例详解
Apr 28 Python
linux环境下python中MySQLdb模块的安装方法
Jun 16 Python
详解supervisor使用教程
Nov 21 Python
python实现最长公共子序列
May 22 Python
Python爬虫包BeautifulSoup简介与安装(一)
Jun 17 Python
Python实现图片转字符画的代码实例
Feb 22 Python
python Tcp协议发送和接收信息的例子
Jul 22 Python
django中使用Celery 布式任务队列过程详解
Jul 29 Python
python numpy 常用随机数的产生方法的实现
Aug 21 Python
python自动化工具之pywinauto实例详解
Aug 26 Python
Python进程池Pool应用实例分析
Nov 27 Python
tensorflow 实现数据类型转换
Feb 17 #Python
Django Haystack 全文检索与关键词高亮的实现
Feb 17 #Python
python使用docx模块读写docx文件的方法与docx模块常用方法详解
Feb 17 #Python
python itsdangerous模块的具体使用方法
Feb 17 #Python
django-crontab实现服务端的定时任务的示例代码
Feb 17 #Python
TensorFlow通过文件名/文件夹名获取标签,并加入队列的实现
Feb 17 #Python
Django 项目通过加载不同env文件来区分不同环境
Feb 17 #Python
You might like
PHP获取网址的顶级域名函数代码
2012/09/24 PHP
php获取网页请求状态程序示例
2014/06/17 PHP
PHP实现更新中间关联表数据的两种方法
2014/09/01 PHP
php使用redis的有序集合zset实现延迟队列应用示例
2020/02/20 PHP
多个iframe自动调整大小的问题
2006/09/18 Javascript
比较全面的event对像在IE与FF中的区别 推荐
2009/09/21 Javascript
js获取浏览器的可视区域尺寸的实现代码
2011/11/30 Javascript
javascript计时器事件使用详解
2014/01/07 Javascript
js调用webservice构造SOAP进行身份验证
2016/04/27 Javascript
jQuery页面加载初始化的3种方法(推荐)
2016/06/02 Javascript
JS中判断字符串中出现次数最多的字符及出现的次数的简单实例
2016/06/03 Javascript
小程序开发实战:实现九宫格界面的导航的代码实现
2017/01/19 Javascript
[js高手之路]原型式继承与寄生式继承详解
2017/08/28 Javascript
vue2.0在table中实现全选和反选的示例代码
2017/11/04 Javascript
Vue仿今日头条实例详解
2018/02/06 Javascript
浅谈React中的元素、组件、实例和节点
2018/02/27 Javascript
详解关于微信setData回调函数中的坑
2019/02/18 Javascript
Vue CLI3.0中使用jQuery和Bootstrap的方法
2019/02/28 jQuery
[01:00]一分钟回顾2018DOTA2亚洲邀请赛现场活动
2018/04/07 DOTA
python实现的jpg格式图片修复代码
2015/04/21 Python
python数组复制拷贝的实现方法
2015/06/09 Python
详解Django框架中用户的登录和退出的实现
2015/07/23 Python
在python中实现调用可执行文件.exe的3种方法
2019/07/07 Python
python如何爬取网站数据并进行数据可视化
2019/07/08 Python
python 怎样将dataframe中的字符串日期转化为日期的方法
2019/09/26 Python
python实现修改固定模式的字符串内容操作示例
2019/12/30 Python
使用pytorch和torchtext进行文本分类的实例
2020/01/08 Python
六种酷炫Python运行进度条效果的实现代码
2020/07/17 Python
Python3 pyecharts生成Html文件柱状图及折线图代码实例
2020/09/29 Python
python从PDF中提取数据的示例
2020/10/30 Python
全球知名鞋履品牌授权零售商:Journeys
2016/09/17 全球购物
AVON雅芳官网:世界上最大的美容化妆品公司之一
2016/11/02 全球购物
排序都有哪几种方法?请列举。用JAVA实现一个快速排序
2014/02/16 面试题
出租房屋协议书
2014/09/14 职场文书
2014财务人员自我评价范文
2014/09/21 职场文书
python 机器学习的标准化、归一化、正则化、离散化和白化
2021/04/16 Python