完美解决keras 读取多个hdf5文件进行训练的问题


Posted in Python onJuly 01, 2020

用keras进行大数据训练,为了加快训练,需要提前制作训练集。

由于HDF5的特性,所有数据需要一次性读入到内存中,才能保存。

为此,我采用分批次分为2个以上HDF5进行存储。

1、先读取每个标签下的图片,并设置标签

def load_dataset(path_name,data_path):
 images = []
 labels = []
 train_images = []
 valid_images = [] 
 train_labels = []
 valid_labels = []
 counter = 0
 allpath = os.listdir(path_name)
 nb_classes = len(allpath)
 print("label_num: ",nb_classes)
 
 for child_dir in allpath:
 child_path = os.path.join(path_name, child_dir)
 for dir_image in os.listdir(child_path):
  if dir_image.endswith('.jpg'):
  img = cv2.imread(os.path.join(child_path, dir_image))  
  image = misc.imresize(img, (IMAGE_SIZE, IMAGE_SIZE), interp='bilinear')
  #resized_img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
  images.append(image)
  labels.append(counter)

2、该标签下的数据集分割为训练集(train images),验证集(val images),训练标签(train labels),验证标签

(val labels)

def split_dataset(images, labels): 

 train_images, valid_images, train_labels, valid_labels = train_test_split(images,\
 labels, test_size = 0.2, random_state = random.randint(0, 100)) 
  
 #print(train_images.shape[0], 'train samples')
 #print(valid_images.shape[0], 'valid samples') 
 return train_images, valid_images, train_labels ,valid_labels

3、分割后的数据分别添加到总的训练集,验证集,训练标签,验证标签。

其次,清空原有的图片集和标签集,目的是节省内存。假如一次性读入多个标签的数据集与标签集,进行数据分割后,会占用大于单纯进行上述操作两倍以上的内存。

images = np.array(images) 
t_images, v_images, t_labels ,v_labels = split_dataset(images, labels) 
for i in range(len(t_images)):
 train_images.append(t_images[i])
 train_labels.append(t_labels[i]) 
for j in range(len(v_images)):
 valid_images.append(v_images[j])
 valid_labels.append(v_labels[j])
if counter%50== 49:
 print( counter+1 , "is read to the memory!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
 
images = []
labels = [] 
counter = counter + 1 

print("train_images num: ", len(train_images), " ", "valid_images num: ",len(valid_images))

4、进行判断,直到读到自己自己分割的那个标签。

开始进行写入。写入之前,为了更好地训练模型,需要把对应的图片集和标签打乱顺序。

if ((counter % 4316 == 4315) or (counter == nb_classes - 1)): 
  
  print("start write images and labels data...................................................................")  
  num = counter // 5000
  dirs = data_path + "/" + "h5_" + str(num - 1)
  if not os.path.exists(dirs):
  os.makedirs(dirs)
  data2h5(dirs, t_images, v_images, t_labels ,v_labels)

对应打乱顺序并写入到HDF5

def data2h5(dirs_path, train_images, valid_images, train_labels ,valid_labels):
 
 TRAIN_HDF5 = dirs_path + '/' + "train.hdf5"
 VAL_HDF5 = dirs_path + '/' + "val.hdf5"
 
 #shuffle
 state1 = np.random.get_state()
 np.random.shuffle(train_images)
 np.random.set_state(state1)
 np.random.shuffle(train_labels)
 
 state2 = np.random.get_state()
 np.random.shuffle(valid_images)
 np.random.set_state(state2)
 np.random.shuffle(valid_labels)
 
 datasets = [
 ("train",train_images,train_labels,TRAIN_HDF5),
 ("val",valid_images,valid_labels,VAL_HDF5)]
 
 for (dType,images,labels,outputPath) in datasets:
 # HDF5 initial
 f = h5py.File(outputPath, "w")
 f.create_dataset("x_"+dType, data=images)
 f.create_dataset("y_"+dType, data=labels)
 #f.create_dataset("x_"+dType, data=images, compression="gzip", compression_opts=9)
 #f.create_dataset("y_"+dType, data=labels, compression="gzip", compression_opts=9)
 f.close()

5、判断文件全部读入

def read_dataset(dirs):
 
 files = os.listdir(dirs)
 print(files)
 for file in files:
 path = dirs+'/' + file
 dataset = h5py.File(path, "r")
 file = file.split('.')
 set_x_orig = dataset["x_"+file[0]].shape[0]
 set_y_orig = dataset["y_"+file[0]].shape[0]

 print(set_x_orig)
 print(set_y_orig)

6、训练中,采用迭代器读入数据

def generator(self, datagen, mode):
 
 passes=np.inf
 aug = ImageDataGenerator(
  featurewise_center = False,  
  samplewise_center = False,  
  featurewise_std_normalization = False, 
  samplewise_std_normalization = False, 
  zca_whitening = False,   
  rotation_range = 20,   
  width_shift_range = 0.2,  
  height_shift_range = 0.2,  
  horizontal_flip = True,  
  vertical_flip = False)  
 
 epochs = 0  
 # 默认是无限循环遍历
 
 while epochs < passes:
  # 遍历数据
  file_dir = os.listdir(self.data_path)
  for file in file_dir:
  #print(file)
  file_path = os.path.join(self.data_path,file)
  TRAIN_HDF5 = file_path +"/train.hdf5"
  VAL_HDF5 = file_path +"/val.hdf5"
  #TEST_HDF5 = file_path +"/test.hdf5"
  
  db_t = h5py.File(TRAIN_HDF5)
  numImages_t = db_t['y_train'].shape[0] 
  db_v = h5py.File(VAL_HDF5)
  numImages_v = db_v['y_val'].shape[0] 
  
  if mode == "train":  
   for i in np.arange(0, numImages_t, self.BS):
   
   images = db_t['x_train'][i: i+self.BS]
   labels = db_t['y_train'][i: i+self.BS]
   
   if K.image_data_format() == 'channels_first':
   
    images = images.reshape(images.shape[0], 3, IMAGE_SIZE,IMAGE_SIZE) 
   else:
    images = images.reshape(images.shape[0], IMAGE_SIZE, IMAGE_SIZE, 3) 
   
   images = images.astype('float32')
   images = images/255   
      
   if datagen :
    (images,labels) = next(aug.flow(images,labels,batch_size = self.BS))   
      
   # one-hot编码
   if self.binarize:
    labels = np_utils.to_categorical(labels,self.classes)   
   
   yield ({'input_1': images}, {'softmax': labels})
    
  elif mode == "val":
   for i in np.arange(0, numImages_v, self.BS):
   images = db_v['x_val'][i: i+self.BS]
   labels = db_v['y_val'][i: i+self.BS] 
   
   if K.image_data_format() == 'channels_first':
   
    images = images.reshape(images.shape[0], 3, IMAGE_SIZE,IMAGE_SIZE) 
   else:
    images = images.reshape(images.shape[0], IMAGE_SIZE, IMAGE_SIZE, 3) 
   
   images = images.astype('float32')
   images = images/255   
   
   if datagen :
    (images,labels) = next(aug.flow(images,labels,batch_size = self.BS))   

   #one-hot编码
   if self.binarize:
    labels = np_utils.to_categorical(labels,self.classes) 
    
   yield ({'input_1': images}, {'softmax': labels})
     
  epochs += 1

7、至此,就大功告成了

完整的代码:

# -*- coding: utf-8 -*-
"""
Created on Mon Feb 12 20:46:12 2018

@author: william_yue
"""
import os
import numpy as np
import cv2
import random
from scipy import misc
import h5py
from sklearn.model_selection import train_test_split
from keras import backend as K
K.clear_session()
from keras.utils import np_utils

IMAGE_SIZE = 128
 
# 加载数据集并按照交叉验证的原则划分数据集并进行相关预处理工作
def split_dataset(images, labels): 
 # 导入了sklearn库的交叉验证模块,利用函数train_test_split()来划分训练集和验证集
 # 划分出了20%的数据用于验证,80%用于训练模型
 train_images, valid_images, train_labels, valid_labels = train_test_split(images,\
 labels, test_size = 0.2, random_state = random.randint(0, 100)) 
 return train_images, valid_images, train_labels ,valid_labels
 
def data2h5(dirs_path, train_images, valid_images, train_labels ,valid_labels):
 
#def data2h5(dirs_path, train_images, valid_images, test_images, train_labels ,valid_labels, test_labels):
 
 TRAIN_HDF5 = dirs_path + '/' + "train.hdf5"
 VAL_HDF5 = dirs_path + '/' + "val.hdf5"
 
 #采用标签与图片相同的顺序分别打乱训练集与验证集
 state1 = np.random.get_state()
 np.random.shuffle(train_images)
 np.random.set_state(state1)
 np.random.shuffle(train_labels)
 
 state2 = np.random.get_state()
 np.random.shuffle(valid_images)
 np.random.set_state(state2)
 np.random.shuffle(valid_labels)
 
 datasets = [
 ("train",train_images,train_labels,TRAIN_HDF5),
 ("val",valid_images,valid_labels,VAL_HDF5)]
 
 for (dType,images,labels,outputPath) in datasets:
 # 初始化HDF5写入
 f = h5py.File(outputPath, "w")
 f.create_dataset("x_"+dType, data=images)
 f.create_dataset("y_"+dType, data=labels)
 #f.create_dataset("x_"+dType, data=images, compression="gzip", compression_opts=9)
 #f.create_dataset("y_"+dType, data=labels, compression="gzip", compression_opts=9)
 f.close()

def read_dataset(dirs):
 files = os.listdir(dirs)
 print(files)
 for file in files:
 path = dirs+'/' + file 
 file_read = os.listdir(path)
 for i in file_read:
  path_read = os.path.join(path, i)
  dataset = h5py.File(path_read, "r")
  i = i.split('.')
  set_x_orig = dataset["x_"+i[0]].shape[0]
  set_y_orig = dataset["y_"+i[0]].shape[0]
  print(set_x_orig)
  print(set_y_orig)

#循环读取每个标签集下的所有图片
def load_dataset(path_name,data_path):
 images = []
 labels = []
 train_images = []
 valid_images = []
 train_labels = []
 valid_labels = []
 counter = 0
 allpath = os.listdir(path_name)
 nb_classes = len(allpath)
 print("label_num: ",nb_classes)
 
 for child_dir in allpath:
 child_path = os.path.join(path_name, child_dir)
 for dir_image in os.listdir(child_path):
  if dir_image.endswith('.jpg'):
  img = cv2.imread(os.path.join(child_path, dir_image))  
  image = misc.imresize(img, (IMAGE_SIZE, IMAGE_SIZE), interp='bilinear')
  #resized_img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
  images.append(image)
  labels.append(counter)
   
 images = np.array(images) 
 t_images, v_images, t_labels ,v_labels = split_dataset(images, labels) 
 for i in range(len(t_images)):
  train_images.append(t_images[i])
  train_labels.append(t_labels[i]) 
 for j in range(len(v_images)):
  valid_images.append(v_images[j])
  valid_labels.append(v_labels[j])
 if counter%50== 49:
  print( counter+1 , "is read to the memory!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
  
 images = []
 labels = [] 
 
 if ((counter % 4316 == 4315) or (counter == nb_classes - 1)): 
  print("train_images num: ", len(train_images), "  ", "valid_images num: ",len(valid_images)) 
  print("start write images and labels data...................................................................")  
  num = counter // 5000
  dirs = data_path + "/" + "h5_" + str(num - 1)
  if not os.path.exists(dirs):
  os.makedirs(dirs)
  data2h5(dirs, train_images, valid_images, train_labels ,valid_labels)
  #read_dataset(dirs)
  print("File HDF5_%d "%num, " id done!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
  train_images = []
  valid_images = []
  train_labels = []
  valid_labels = [] 
 counter = counter + 1 
 print("All File HDF5 done!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
 read_dataset(data_path) 

#读取训练数据集的文件夹,把他们的名字返回给一个list
def read_name_list(path_name):
 name_list = []
 for child_dir in os.listdir(path_name):
 name_list.append(child_dir)
 return name_list

if __name__ == '__main__':
 path = "data"
 data_path = "data_hdf5_half"
 if not os.path.exists(data_path):
 os.makedirs(data_path)
 load_dataset(path,data_path)

以上这篇完美解决keras 读取多个hdf5文件进行训练的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3中多线程编程的队列运作示例
Apr 16 Python
python分析网页上所有超链接的方法
May 08 Python
一个基于flask的web应用诞生 使用模板引擎和表单插件(2)
Apr 11 Python
详解python之多进程和进程池(Processing库)
Jun 09 Python
python导出hive数据表的schema实例代码
Jan 22 Python
Python实现的redis分布式锁功能示例
May 29 Python
如何使用Python进行OCR识别图片中的文字
Apr 01 Python
python实现websocket的客户端压力测试
Jun 25 Python
详解Python并发编程之创建多线程的几种方法
Aug 23 Python
pygame实现打字游戏
Feb 19 Python
python实现数学模型(插值、拟合和微分方程)
Nov 13 Python
python爬虫智能翻页批量下载文件的实例详解
Feb 02 Python
学python需要去培训机构吗
Jul 01 #Python
详解python logging日志传输
Jul 01 #Python
python怎么调用自己的函数
Jul 01 #Python
解决keras模型保存h5文件提示无此目录问题
Jul 01 #Python
如何解决安装python3.6.1失败
Jul 01 #Python
python如何求圆的面积
Jul 01 #Python
python怎么判断素数
Jul 01 #Python
You might like
php中文件上传的安全问题
2006/10/09 PHP
理解php原理的opcodes(操作码)
2010/10/26 PHP
基于php常用函数总结(数组,字符串,时间,文件操作)
2013/06/27 PHP
通过dbi使用perl连接mysql数据库的方法
2014/04/16 PHP
使用Zookeeper分布式部署PHP应用程序
2019/03/15 PHP
Prototype1.5 rc2版指南最后一篇之Position
2007/01/10 Javascript
jQuery中的.bind()、.live()和.delegate()之间区别分析
2011/06/08 Javascript
js创建子窗口并且回传值示例代码
2013/07/02 Javascript
JavaScript中的正则表达式简明总结
2014/04/04 Javascript
使用javascript实现json数据以csv格式下载
2015/01/09 Javascript
JavaScript定时显示广告代码分享
2015/03/02 Javascript
jQuery插件实现大图全屏图片相册
2015/03/14 Javascript
arcgis for js 修改infowindow样式的方法
2016/11/02 Javascript
js微信扫描二维码登录网站技术原理
2016/12/01 Javascript
jQuery Validate验证框架详解(推荐)
2016/12/17 Javascript
bootstrap fileinput组件整合Springmvc上传图片到本地磁盘
2017/05/11 Javascript
JavaScript实现秒杀时钟倒计时
2019/09/29 Javascript
Vue——解决报错 Computed property &quot;****&quot; was assigned to but it has no setter.
2020/12/19 Vue.js
JS实现点击掉落特效
2021/01/29 Javascript
[50:24]VGJ.S vs Pain 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/20 DOTA
用Python实现QQ游戏大家来找茬辅助工具
2014/09/14 Python
python实现两个文件合并功能
2018/04/01 Python
python实现自主查询实时天气
2018/06/22 Python
python字典一键多值实例代码分享
2019/06/14 Python
图解python全局变量与局部变量相关知识
2019/11/02 Python
python利用百度云接口实现车牌识别的示例
2020/02/21 Python
Ellos瑞典官网:北欧地区时尚、美容和住宅领域领先的电子商务网站
2019/11/21 全球购物
PHP数据运算类型都有哪些
2013/11/05 面试题
农民工工资发放承诺书
2014/03/31 职场文书
人身意外保险授权委托书
2014/10/01 职场文书
城镇居民医疗保险工作总结
2015/08/10 职场文书
《语言的突破》读后感3篇
2019/12/12 职场文书
Java实现聊天机器人完善版
2021/07/04 Java/Android
Spring Boot 排除某个类加载注入IOC的操作
2021/08/02 Java/Android
IDEA 2022 Translation 未知错误 翻译文档失败
2022/04/24 Java/Android
clear 万能清除浮动(clearfix:after)
2023/05/21 HTML / CSS