将自己的数据集制作成TFRecord格式教程


Posted in Python onFebruary 17, 2020

在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入

此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格

式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。

1. 原本的数据集

此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。

将自己的数据集制作成TFRecord格式教程

2.制作成TFRecord格式

tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。

#生成整数型的属性
def _int64_feature(value):
 return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
 
#生成字符串类型的属性
def _bytes_feature(value):
 return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
 
#制作TFRecord格式
def createTFRecord(filename,mapfile):
 class_map = {}
 data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
 classes = {'xiansu60','xiansu100'}
 #输出TFRecord文件的地址
 
 writer = tf.python_io.TFRecordWriter(filename)
 
 for index,name in enumerate(classes):
  class_path=data_dir+name+'/'
  class_map[index] = name
  for img_name in os.listdir(class_path):
   img_path = class_path + img_name #每个图片的地址
   img = Image.open(img_path)
   img= img.resize((224,224))
   img_raw = img.tobytes()   #将图片转化成二进制格式
   example = tf.train.Example(features = tf.train.Features(feature = {
    'label':_int64_feature(index),
    'image_raw': _bytes_feature(img_raw)
   }))
   writer.write(example.SerializeToString())
 writer.close()
 
 txtfile = open(mapfile,'w+')
 for key in class_map.keys():
  txtfile.writelines(str(key)+":"+class_map[key]+"\n")
 txtfile.close()

此段代码,运行完后会产生生成的.tfrecord文件。

3. 读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程

#读取train.tfrecord中的数据
def read_and_decode(filename): 
 #创建一个reader来读取TFRecord文件中的样例
 reader = tf.TFRecordReader()
 #创建一个队列来维护输入文件列表
 filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
 #从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
 _,serialized_example = reader.read(filename_queue)
#  print _,serialized_example
 
 #解析读入的一个样例,如果需要解析多个,可以用parse_example
 features = tf.parse_single_example(
 serialized_example,
 features = {'label':tf.FixedLenFeature([], tf.int64),
    'image_raw': tf.FixedLenFeature([], tf.string),})
 #将字符串解析成图像对应的像素数组
 img = tf.decode_raw(features['image_raw'], tf.uint8)
 img = tf.reshape(img,[224, 224, 3]) #reshape为128*128*3通道图片
 img = tf.image.per_image_standardization(img)
 labels = tf.cast(features['label'], tf.int32)
 return img, labels

4. 将图片几个一打包,形成batch

def createBatch(filename,batchsize):
 images,labels = read_and_decode(filename)
 
 min_after_dequeue = 10
 capacity = min_after_dequeue + 3 * batchsize
 
 image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
              batch_size=batchsize, 
              capacity=capacity, 
              min_after_dequeue=min_after_dequeue
              )
 
 label_batch = tf.one_hot(label_batch,depth=2)
 return image_batch, label_batch

5.主函数

if __name__ =="__main__":
 #训练图片两张为一个batch,进行训练,测试图片一起进行测试
 mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
 train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
#  createTFRecord(train_filename,mapfile)
 test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
#  createTFRecord(test_filename,mapfile)
 image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
 test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
 with tf.Session() as sess:
  initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
  sess.run(initop)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess = sess, coord = coord)
 
  try:
   step = 0
   while 1:
    _image_batch,_label_batch = sess.run([image_batch,label_batch])
    step += 1
    print step
    print (_label_batch)
  except tf.errors.OutOfRangeError:
   print (" trainData done!")
   
  try:
   step = 0
   while 1:
    _test_images,_test_labels = sess.run([test_images,test_labels])
    step += 1
    print step
 #     print _image_batch.shape
    print (_test_labels)
  except tf.errors.OutOfRangeError:
   print (" TEST done!")
  coord.request_stop()
  coord.join(threads)

此时,生成的batch,就可以feed进网络了。

以上这篇将自己的数据集制作成TFRecord格式教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python的Twisted框架中reactor事件管理器的用法
May 25 Python
Python中的复制操作及copy模块中的浅拷贝与深拷贝方法
Jul 02 Python
python xml.etree.ElementTree遍历xml所有节点实例详解
Dec 04 Python
python3模块smtplib实现发送邮件功能
May 22 Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 Python
python栈的基本定义与使用方法示例【初始化、赋值、入栈、出栈等】
Oct 24 Python
使用pygame写一个古诗词填空通关游戏
Dec 03 Python
tensorflow 保存模型和取出中间权重例子
Jan 24 Python
Python基于百度AI实现OCR文字识别
Apr 02 Python
Keras - GPU ID 和显存占用设定步骤
Jun 22 Python
Node.js 和 Python之间该选择哪个?
Aug 05 Python
用pushplus+python监控亚马逊到货动态推送微信
Jan 29 Python
tensorflow 实现数据类型转换
Feb 17 #Python
Django Haystack 全文检索与关键词高亮的实现
Feb 17 #Python
python使用docx模块读写docx文件的方法与docx模块常用方法详解
Feb 17 #Python
python itsdangerous模块的具体使用方法
Feb 17 #Python
django-crontab实现服务端的定时任务的示例代码
Feb 17 #Python
TensorFlow通过文件名/文件夹名获取标签,并加入队列的实现
Feb 17 #Python
Django 项目通过加载不同env文件来区分不同环境
Feb 17 #Python
You might like
PHPExcel读取Excel文件的实现代码
2011/12/06 PHP
php截取字符串函数substr,iconv_substr,mb_substr示例以及优劣分析
2014/06/10 PHP
php获取数组元素中头一个数组元素值的实现方法
2014/12/20 PHP
php+jQuery.uploadify实现文件上传教程
2014/12/26 PHP
Codeigniter实现发送带附件的邮件
2015/03/19 PHP
Yii+MYSQL锁表防止并发情况下重复数据的方法
2016/07/14 PHP
php实现socket推送技术的示例
2017/12/20 PHP
js实现的仿新浪微博完美的时间组件升级版
2011/12/20 Javascript
提高jQuery性能优化的技巧
2015/08/03 Javascript
jQuery中值得注意的trigger方法浅析
2016/12/12 Javascript
Vue.js -- 过滤器使用总结
2017/02/18 Javascript
详谈jQuery中的一些正则匹配表达式
2017/03/08 Javascript
vue脚手架vue-cli的学习使用教程
2017/06/06 Javascript
Spring Boot/VUE中路由传递参数的实现代码
2018/03/02 Javascript
使用vue-router与v-if实现tab切换遇到的问题及解决方法
2018/09/07 Javascript
详解多页应用 Webpack4 配置优化与踩坑记录
2018/10/16 Javascript
Electron 调用命令行(cmd)
2019/09/23 Javascript
解决element-ui里的下拉多选框 el-select 时,默认值不可删除问题
2020/08/14 Javascript
使用Python编写简单的画图板程序的示例教程
2015/12/08 Python
tensorflow学习笔记之简单的神经网络训练和测试
2018/04/15 Python
Python3 安装PyQt5及exe打包图文教程
2019/01/08 Python
Python基于Opencv来快速实现人脸识别过程详解(完整版)
2019/07/11 Python
python实现复制文件到指定目录
2019/10/16 Python
Python运行异常管理解决方案
2020/03/09 Python
Python中qutip用法示例详解
2020/10/02 Python
OpenCV读取与写入图片的实现
2020/10/13 Python
AmazeUI 等分网格的实现示例
2020/08/25 HTML / CSS
Linux内核的同步机制是什么?主要有哪几种内核锁
2013/01/03 面试题
新闻网站实习自我鉴定
2013/09/25 职场文书
公司成立感言
2014/01/11 职场文书
宿舍使用违章电器检讨书
2014/01/12 职场文书
玄武湖导游词
2015/02/05 职场文书
数据结构课程设计心得体会
2016/01/15 职场文书
2016年先进教师个人事迹材料
2016/02/26 职场文书
CSS实现单选折叠菜单功能
2021/11/01 HTML / CSS
python基础之//、/与%的区别详解
2022/06/10 Python