keras 读取多标签图像数据方式


Posted in Python onJune 12, 2020

我所接触的多标签数据,主要包括两类:

1、一张图片属于多个标签,比如,data:一件蓝色的上衣图片.jpg,label:蓝色,上衣。其中label包括两类标签,label1第一类:上衣,裤子,外套。label2第二类,蓝色,黑色,红色。这样两个输出label1,label2都是是分类,我们可以直接把label1和label2整合为一个label,直接编码,比如[蓝色,上衣]编码为[011011]。这样模型的输出也只需要一个输出。实现了多分类。

2、一张图片属于多个标签,但是几个标签不全是分类。比如data:一张结婚现场的图片.jpg,label:高兴,3(表示高兴程度)。这时label1是分类,label2时回归。这种情况就需要多个标签,模型需要多个输出。【其实最好的例子,就是目标检测,不但检测什么物体(分类),还要检测到物体的坐标(回归)】

在这里我主要针对第二种情况加以说明:

keras的ImageDataGenerator.flow_from_directory 只能简单的读取单标签数据。所以我自己写了个data_generate,来生成bathsize多标签数据

keras 读取多标签图像数据方式

#此模块主要用来读取数据集,返回一个数据可迭代对象
#重点是,此模块分批次的把图像读入内存的,而不是一次全读入,有效的减少了内存溢出
import os
import cv2
import numpy as np
import keras
from random import shuffle

#目标图像大小
image_size= (229, 229, 3)
#类别编码
class_dict=dict(zip(['neg','pos','neu'],[0,1,2]))
#处理.txt文件,并加载图片文件夹里的图片名
#txt_path,txt文件路径,data_path,图片文件夹路径

def read_txt(txt_path,data_path):
 # 中间数组
 labels_class = []
 labels_score = []
 with open(txt_path) as f:
 lines_list = f.readlines()
 for lines in lines_list:
  line = lines.split(' ')
  labels_class.append(line[0].rstrip(".jpg"))
  labels_score.append(line[2])
 labels_dict=dict(zip(labels_class,labels_score))
 #处理图片数据集
 all_picture_name = os.listdir(data_path)
 #打乱数据集
 shuffle(all_picture_name)
 all_picture_path=[os.path.join(data_path,one)for one in all_picture_name]
 return all_picture_name,all_picture_path,labels_dict

class data_generate:
 def __init__(self,all_piture_name,all_picture_path,labels_dict,batch_size):
 self.index=0
 self.all_picture_name=all_piture_name
 self.all_picture_path=all_picture_path
 self.labels_dict=labels_dict
 self.batch_size = batch_size
 def get_mini_batch(self):
  while True:
  batch_images=[]
  batch_labels=[]
  batch_class=[]
  batch_score=[]
  for i in range(self.batch_size):
  if(self.index==len(self.all_picture_name)):
   self.index=0

  bgr_image = cv2.imread(self.all_picture_path[self.index])
  if len(bgr_image.shape) == 2: # 若是灰度图则转为三通道
   bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
  rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
  rgb_image=cv2.resize(rgb_image,(image_size[0], image_size[1]))
  img = np.array(rgb_image)
  img=keras.applications.inception_v3.preprocess_input(img)
  batch_images.append(img)
  #label=[]
  label1=self.all_picture_name[self.index].rstrip(".jpg")
  batch_class.append(keras.utils.to_categorical(class_dict[label1[:3]], 3))
  batch_score.append(np.array(self.labels_dict[label1]))
  #batch_labels.append(label)
  self.index+=1
  batch_images=np.array(batch_images)
  batch_class = np.array(batch_class)
  batch_score = np.array(batch_score)
  #注意label的生成batch_class,一个单独数组,batch_score一个单独的数组
  '''
  注释掉的这段代码生成的label是错误的。
  batch_images=[]
  batch_labels=[]
  for i in range(self.batch_size):
  if(self.index==len(self.images)):
   self.index=0
  batch_images.append(self.images[self.index])
  batch_labels.append(self.labels[self.index])
  self.index+=1
  batch_images=np.array(batch_images)
  batch_labels=np.array(batch_labels)
  yield batch_images,batch_labels
  '''
  yield batch_images,[batch_class,batch_score]

接下来就是放入keras.fit_generate中了

history=model.fit_generator(generator=train_data.get_mini_batch(),
   steps_per_epoch=146,
   epochs=300,
   validation_data=test_data.get_mini_batch(),
   validation_steps=34,
   )

以上这篇keras 读取多标签图像数据方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解python时间模块中的datetime模块
Jan 13 Python
详解Python读取配置文件模块ConfigParser
May 11 Python
Python编程pygal绘图实例之XY线
Dec 09 Python
Python用于学习重要算法的模块pygorithm实例浅析
Aug 16 Python
破解安装Pycharm的方法
Oct 19 Python
在PyCharm中三步完成PyPy解释器的配置的方法
Oct 29 Python
Pandas时间序列重采样(resample)方法中closed、label的作用详解
Dec 10 Python
PyTorch中反卷积的用法详解
Dec 30 Python
Jupyter 无法下载文件夹如何实现曲线救国
Apr 22 Python
Python正则表达式如何匹配中文
May 27 Python
python 实现表情识别
Nov 21 Python
基于Python实现股票收益率分析
Apr 02 Python
Python数据可视化图实现过程详解
Jun 12 #Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
Jun 12 #Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
Jun 12 #Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 #Python
python + selenium 刷B站播放量的实例代码
Jun 12 #Python
解决Keras自带数据集与预训练model下载太慢问题
Jun 12 #Python
keras导入weights方式
Jun 12 #Python
You might like
PHP 加密与解密的斗争
2009/04/17 PHP
php封装的图片(缩略图)处理类完整实例
2016/10/19 PHP
PHP基于curl实现模拟微信浏览器打开微信链接的方法示例
2019/02/15 PHP
PHP之多条件混合筛选功能的实现方法
2019/10/09 PHP
基于jquery的一个拖拽到指定区域内的效果
2011/09/21 Javascript
ASP.NET jQuery 实例5 (显示CheckBoxList成员选中的内容)
2012/01/13 Javascript
Array.prototype.concat不是通用方法反驳[译]
2012/09/20 Javascript
JavaScript中的字符串操作详解
2013/11/12 Javascript
NodeJS中Buffer模块详解
2015/01/07 NodeJs
JavaScript iframe数据共享接口实现方法
2016/01/06 Javascript
javascript实现瀑布流加载图片原理
2016/02/02 Javascript
jQuery插件echarts实现的循环生成图效果示例【附demo源码下载】
2017/03/04 Javascript
利用node.js+mongodb如何搭建一个简单登录注册的功能详解
2017/07/30 Javascript
浅谈PDF.js使用心得
2018/06/07 Javascript
AngularJS修改model值时,显示内容不变的实例
2018/09/13 Javascript
jquery分页优化操作实例分析
2019/08/23 jQuery
小程序如何支持使用 async/await详解
2019/09/12 Javascript
vue 实现 rem 布局或vw 布局的方法
2019/11/13 Javascript
详解vue高级特性
2020/06/09 Javascript
详解template标签用法(含vue中的用法总结)
2021/01/12 Vue.js
SublimeText 2编译python出错的解决方法(The system cannot find the file specified)
2013/11/27 Python
Python itertools模块详解
2015/05/09 Python
举例讲解Python中字典的合并值相加与异或对比
2016/06/04 Python
python中使用正则表达式的后向搜索肯定模式(推荐)
2017/11/11 Python
PyTorch快速搭建神经网络及其保存提取方法详解
2018/04/28 Python
Python决策树之基于信息增益的特征选择示例
2018/06/25 Python
在Python中利用pickle保存变量的实例
2019/12/30 Python
Python使用Pyqt5实现简易浏览器(最新版本测试过)
2020/04/27 Python
日本最新流行服饰网购:Nissen
2016/07/24 全球购物
Smilodox官方运动服装店:从运动服到健身配件
2020/08/27 全球购物
财务担保书范文
2014/04/02 职场文书
小学生暑假家长评语
2014/04/17 职场文书
护士节活动总结
2014/08/29 职场文书
Nginx使用Lua模块实现WAF的原理解析
2021/09/04 Servers
关于CSS自定义属性与前端页面的主题切换问题
2022/03/21 HTML / CSS
JS实现简单九宫格抽奖
2022/06/28 Javascript