将自己的数据集制作成TFRecord格式教程


Posted in Python onFebruary 17, 2020

在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入

此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格

式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。

1. 原本的数据集

此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。

将自己的数据集制作成TFRecord格式教程

2.制作成TFRecord格式

tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。

#生成整数型的属性
def _int64_feature(value):
 return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
 
#生成字符串类型的属性
def _bytes_feature(value):
 return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
 
#制作TFRecord格式
def createTFRecord(filename,mapfile):
 class_map = {}
 data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
 classes = {'xiansu60','xiansu100'}
 #输出TFRecord文件的地址
 
 writer = tf.python_io.TFRecordWriter(filename)
 
 for index,name in enumerate(classes):
  class_path=data_dir+name+'/'
  class_map[index] = name
  for img_name in os.listdir(class_path):
   img_path = class_path + img_name #每个图片的地址
   img = Image.open(img_path)
   img= img.resize((224,224))
   img_raw = img.tobytes()   #将图片转化成二进制格式
   example = tf.train.Example(features = tf.train.Features(feature = {
    'label':_int64_feature(index),
    'image_raw': _bytes_feature(img_raw)
   }))
   writer.write(example.SerializeToString())
 writer.close()
 
 txtfile = open(mapfile,'w+')
 for key in class_map.keys():
  txtfile.writelines(str(key)+":"+class_map[key]+"\n")
 txtfile.close()

此段代码,运行完后会产生生成的.tfrecord文件。

3. 读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程

#读取train.tfrecord中的数据
def read_and_decode(filename): 
 #创建一个reader来读取TFRecord文件中的样例
 reader = tf.TFRecordReader()
 #创建一个队列来维护输入文件列表
 filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
 #从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
 _,serialized_example = reader.read(filename_queue)
#  print _,serialized_example
 
 #解析读入的一个样例,如果需要解析多个,可以用parse_example
 features = tf.parse_single_example(
 serialized_example,
 features = {'label':tf.FixedLenFeature([], tf.int64),
    'image_raw': tf.FixedLenFeature([], tf.string),})
 #将字符串解析成图像对应的像素数组
 img = tf.decode_raw(features['image_raw'], tf.uint8)
 img = tf.reshape(img,[224, 224, 3]) #reshape为128*128*3通道图片
 img = tf.image.per_image_standardization(img)
 labels = tf.cast(features['label'], tf.int32)
 return img, labels

4. 将图片几个一打包,形成batch

def createBatch(filename,batchsize):
 images,labels = read_and_decode(filename)
 
 min_after_dequeue = 10
 capacity = min_after_dequeue + 3 * batchsize
 
 image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
              batch_size=batchsize, 
              capacity=capacity, 
              min_after_dequeue=min_after_dequeue
              )
 
 label_batch = tf.one_hot(label_batch,depth=2)
 return image_batch, label_batch

5.主函数

if __name__ =="__main__":
 #训练图片两张为一个batch,进行训练,测试图片一起进行测试
 mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
 train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
#  createTFRecord(train_filename,mapfile)
 test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
#  createTFRecord(test_filename,mapfile)
 image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
 test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
 with tf.Session() as sess:
  initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
  sess.run(initop)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess = sess, coord = coord)
 
  try:
   step = 0
   while 1:
    _image_batch,_label_batch = sess.run([image_batch,label_batch])
    step += 1
    print step
    print (_label_batch)
  except tf.errors.OutOfRangeError:
   print (" trainData done!")
   
  try:
   step = 0
   while 1:
    _test_images,_test_labels = sess.run([test_images,test_labels])
    step += 1
    print step
 #     print _image_batch.shape
    print (_test_labels)
  except tf.errors.OutOfRangeError:
   print (" TEST done!")
  coord.request_stop()
  coord.join(threads)

此时,生成的batch,就可以feed进网络了。

以上这篇将自己的数据集制作成TFRecord格式教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python基于matplotlib实现绘制三维图形功能示例
Jan 18 Python
Python设计模式之组合模式原理与用法实例分析
Jan 11 Python
对python列表里的字典元素去重方法详解
Jan 21 Python
Python使用numpy模块实现矩阵和列表的连接操作方法
Jun 26 Python
Django model 中设置联合约束和联合索引的方法
Aug 06 Python
Python爬虫:url中带字典列表参数的编码转换方法
Aug 21 Python
python上传时包含boundary时的解决方法
Apr 08 Python
python由已知数组快速生成新数组的方法
Apr 08 Python
Python Socket TCP双端聊天功能实现过程详解
Jun 15 Python
python中selenium库的基本使用详解
Jul 31 Python
Ubuntu20.04环境安装tensorflow2的方法步骤
Jan 29 Python
python 基于DDT实现数据驱动测试
Feb 18 Python
tensorflow 实现数据类型转换
Feb 17 #Python
Django Haystack 全文检索与关键词高亮的实现
Feb 17 #Python
python使用docx模块读写docx文件的方法与docx模块常用方法详解
Feb 17 #Python
python itsdangerous模块的具体使用方法
Feb 17 #Python
django-crontab实现服务端的定时任务的示例代码
Feb 17 #Python
TensorFlow通过文件名/文件夹名获取标签,并加入队列的实现
Feb 17 #Python
Django 项目通过加载不同env文件来区分不同环境
Feb 17 #Python
You might like
解析PHPExcel使用的常用说明以及把PHPExcel整合进CI框架的介绍
2013/06/24 PHP
php中strtotime函数用法详解
2014/11/15 PHP
PHPStorm+XDebug进行调试图文教程
2016/06/13 PHP
PHP实现表单提交时去除斜杠的方法
2016/12/26 PHP
简单谈谈PHP中的trait
2017/02/25 PHP
PHP 枚举类型的管理与设计知识点总结
2020/02/13 PHP
js TextArea的选中区域处理
2010/12/28 Javascript
js鼠标点击事件在各个浏览器中的写法及Event对象属性介绍
2013/01/24 Javascript
jQuery 关于伪类选择符的使用说明
2013/04/24 Javascript
文本框回车提交与禁止提交示例
2013/09/27 Javascript
jQuery中animate动画第二次点击事件没反应
2015/05/07 Javascript
javascript实现textarea中tab键的缩排处理方法
2015/06/26 Javascript
jQuery实现table表格信息的展开和缩小功能示例
2018/07/21 jQuery
iphone刘海屏页面适配方法
2019/05/07 Javascript
vue中使用v-for时为什么不能用index作为key
2020/04/04 Javascript
Vue全局使用less样式,组件使用全局样式文件中定义的变量操作
2020/10/21 Javascript
vue实现树状表格效果
2020/12/29 Vue.js
python写入中英文字符串到文件的方法
2015/05/06 Python
在Linux中通过Python脚本访问mdb数据库的方法
2015/05/06 Python
numpy中索引和切片详解
2017/12/15 Python
深入分析python中整型不会溢出问题
2018/06/18 Python
Python模拟简单电梯调度算法示例
2018/08/20 Python
python 使用re.search()筛选后 选取部分结果的方法
2018/11/28 Python
Python操作MySQL数据库实例详解【安装、连接、增删改查等】
2020/01/17 Python
VSCode配合pipenv搞定虚拟环境的实现方法
2020/05/17 Python
PyCharm中关于安装第三方包的三个建议
2020/09/17 Python
一款基于css3和jquery实现的动画显示弹出层按钮教程
2015/01/04 HTML / CSS
html5 浏览器支持 如何让所有的浏览器都支持HTML5标签样式
2012/12/07 HTML / CSS
纬创Java面试题笔试题
2014/10/02 面试题
DataReader和DataSet的异同
2014/12/31 面试题
元旦联欢会感言
2014/03/04 职场文书
企业文化宣传标语
2014/06/09 职场文书
学生顶撞老师的检讨书
2014/09/17 职场文书
护理培训心得体会
2016/01/22 职场文书
解决SpringBoot跨域的三种方式
2021/06/26 Java/Android
星际争霸 Light vs Action 一场把教主看到鬼畜的比赛
2022/04/01 星际争霸