将自己的数据集制作成TFRecord格式教程


Posted in Python onFebruary 17, 2020

在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入

此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格

式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。

1. 原本的数据集

此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。

将自己的数据集制作成TFRecord格式教程

2.制作成TFRecord格式

tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。

#生成整数型的属性
def _int64_feature(value):
 return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
 
#生成字符串类型的属性
def _bytes_feature(value):
 return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
 
#制作TFRecord格式
def createTFRecord(filename,mapfile):
 class_map = {}
 data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
 classes = {'xiansu60','xiansu100'}
 #输出TFRecord文件的地址
 
 writer = tf.python_io.TFRecordWriter(filename)
 
 for index,name in enumerate(classes):
  class_path=data_dir+name+'/'
  class_map[index] = name
  for img_name in os.listdir(class_path):
   img_path = class_path + img_name #每个图片的地址
   img = Image.open(img_path)
   img= img.resize((224,224))
   img_raw = img.tobytes()   #将图片转化成二进制格式
   example = tf.train.Example(features = tf.train.Features(feature = {
    'label':_int64_feature(index),
    'image_raw': _bytes_feature(img_raw)
   }))
   writer.write(example.SerializeToString())
 writer.close()
 
 txtfile = open(mapfile,'w+')
 for key in class_map.keys():
  txtfile.writelines(str(key)+":"+class_map[key]+"\n")
 txtfile.close()

此段代码,运行完后会产生生成的.tfrecord文件。

3. 读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程

#读取train.tfrecord中的数据
def read_and_decode(filename): 
 #创建一个reader来读取TFRecord文件中的样例
 reader = tf.TFRecordReader()
 #创建一个队列来维护输入文件列表
 filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
 #从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
 _,serialized_example = reader.read(filename_queue)
#  print _,serialized_example
 
 #解析读入的一个样例,如果需要解析多个,可以用parse_example
 features = tf.parse_single_example(
 serialized_example,
 features = {'label':tf.FixedLenFeature([], tf.int64),
    'image_raw': tf.FixedLenFeature([], tf.string),})
 #将字符串解析成图像对应的像素数组
 img = tf.decode_raw(features['image_raw'], tf.uint8)
 img = tf.reshape(img,[224, 224, 3]) #reshape为128*128*3通道图片
 img = tf.image.per_image_standardization(img)
 labels = tf.cast(features['label'], tf.int32)
 return img, labels

4. 将图片几个一打包,形成batch

def createBatch(filename,batchsize):
 images,labels = read_and_decode(filename)
 
 min_after_dequeue = 10
 capacity = min_after_dequeue + 3 * batchsize
 
 image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
              batch_size=batchsize, 
              capacity=capacity, 
              min_after_dequeue=min_after_dequeue
              )
 
 label_batch = tf.one_hot(label_batch,depth=2)
 return image_batch, label_batch

5.主函数

if __name__ =="__main__":
 #训练图片两张为一个batch,进行训练,测试图片一起进行测试
 mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
 train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
#  createTFRecord(train_filename,mapfile)
 test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
#  createTFRecord(test_filename,mapfile)
 image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
 test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
 with tf.Session() as sess:
  initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
  sess.run(initop)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess = sess, coord = coord)
 
  try:
   step = 0
   while 1:
    _image_batch,_label_batch = sess.run([image_batch,label_batch])
    step += 1
    print step
    print (_label_batch)
  except tf.errors.OutOfRangeError:
   print (" trainData done!")
   
  try:
   step = 0
   while 1:
    _test_images,_test_labels = sess.run([test_images,test_labels])
    step += 1
    print step
 #     print _image_batch.shape
    print (_test_labels)
  except tf.errors.OutOfRangeError:
   print (" TEST done!")
  coord.request_stop()
  coord.join(threads)

此时,生成的batch,就可以feed进网络了。

以上这篇将自己的数据集制作成TFRecord格式教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python学习之编写查询ip程序
Feb 27 Python
Python实现多线程抓取网页功能实例详解
Jun 08 Python
Python安装图文教程 Pycharm安装教程
Mar 27 Python
详解Python中正则匹配TAB及空格的小技巧
Jul 26 Python
给我一面国旗 python帮你实现
Sep 30 Python
tensorflow实现对张量数据的切片操作方式
Jan 19 Python
Python闭包装饰器使用方法汇总
Jun 29 Python
Python私有属性私有方法应用实例解析
Sep 15 Python
Django 权限管理(permissions)与用户组(group)详解
Nov 30 Python
python - timeit 时间模块
Apr 06 Python
python 进阶学习之python装饰器小结
Sep 04 Python
利用python实时刷新基金估值(摸鱼小工具)
Sep 15 Python
tensorflow 实现数据类型转换
Feb 17 #Python
Django Haystack 全文检索与关键词高亮的实现
Feb 17 #Python
python使用docx模块读写docx文件的方法与docx模块常用方法详解
Feb 17 #Python
python itsdangerous模块的具体使用方法
Feb 17 #Python
django-crontab实现服务端的定时任务的示例代码
Feb 17 #Python
TensorFlow通过文件名/文件夹名获取标签,并加入队列的实现
Feb 17 #Python
Django 项目通过加载不同env文件来区分不同环境
Feb 17 #Python
You might like
非洲第一个咖啡超凡杯大赛承办国—卢旺达的咖啡怎么样
2021/03/03 咖啡文化
多文件上载系统完整版
2006/10/09 PHP
回答PHPCHINA上的几个问题:URL映射
2007/02/14 PHP
几个实用的PHP内置函数使用指南
2014/11/27 PHP
thinkPHP2.1自定义标签库的导入方法详解
2016/07/20 PHP
PHP+MySQL实现模糊查询员工信息功能示例
2018/06/01 PHP
解析js原生方法创建表格效率测试
2013/07/08 Javascript
ExtJs中gridpanel分组后组名排序实例代码
2013/12/02 Javascript
javascript实现浏览器窗口传递参数的方法
2014/09/03 Javascript
基于Jquery实现表单验证
2020/07/20 Javascript
深入理解事件冒泡(Bubble)和事件捕捉(capture)
2016/05/28 Javascript
jQuery基础_入门必看知识点
2016/07/04 Javascript
用jQuery的AJax实现异步访问、异步加载
2016/11/02 Javascript
微信小程序 实战程序简易新闻的制作
2017/01/09 Javascript
angular ng-click防止重复提交实例
2017/06/16 Javascript
JScript实现表格的简单操作
2017/08/15 Javascript
Vue的Class与Style绑定的方法
2017/09/01 Javascript
微信小程序使用navigateTo数据传递的实例
2017/09/26 Javascript
浅谈Vue-cli 命令行工具分析
2017/11/22 Javascript
AngularJS实现的2048小游戏功能【附源码下载】
2018/01/03 Javascript
[48:45]Ti4 循环赛第二日 NEWBEE vs EG
2014/07/11 DOTA
python和bash统计CPU利用率的方法
2015/07/10 Python
python实现隐马尔科夫模型HMM
2018/03/25 Python
详解Python if-elif-else知识点
2018/06/11 Python
python使用for循环计算0-100的整数的和方法
2019/02/01 Python
pytorch多进程加速及代码优化方法
2019/08/19 Python
python查看数据类型的方法
2019/10/12 Python
小 200 行 Python 代码制作一个换脸程序
2020/05/12 Python
python获取整个网页源码的方法
2020/08/03 Python
css3过渡_动力节点Java学院整理
2017/07/11 HTML / CSS
京东国际站:JOYBUY
2017/11/23 全球购物
银行毕业实习自我鉴定
2013/09/19 职场文书
忠诚教育心得体会
2014/09/03 职场文书
2014年林业工作总结
2014/12/05 职场文书
2015年材料员工作总结
2015/04/30 职场文书
Oracle表空间与权限的深入讲解
2021/11/17 Oracle