完美解决keras 读取多个hdf5文件进行训练的问题


Posted in Python onJuly 01, 2020

用keras进行大数据训练,为了加快训练,需要提前制作训练集。

由于HDF5的特性,所有数据需要一次性读入到内存中,才能保存。

为此,我采用分批次分为2个以上HDF5进行存储。

1、先读取每个标签下的图片,并设置标签

def load_dataset(path_name,data_path):
 images = []
 labels = []
 train_images = []
 valid_images = [] 
 train_labels = []
 valid_labels = []
 counter = 0
 allpath = os.listdir(path_name)
 nb_classes = len(allpath)
 print("label_num: ",nb_classes)
 
 for child_dir in allpath:
 child_path = os.path.join(path_name, child_dir)
 for dir_image in os.listdir(child_path):
  if dir_image.endswith('.jpg'):
  img = cv2.imread(os.path.join(child_path, dir_image))  
  image = misc.imresize(img, (IMAGE_SIZE, IMAGE_SIZE), interp='bilinear')
  #resized_img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
  images.append(image)
  labels.append(counter)

2、该标签下的数据集分割为训练集(train images),验证集(val images),训练标签(train labels),验证标签

(val labels)

def split_dataset(images, labels): 

 train_images, valid_images, train_labels, valid_labels = train_test_split(images,\
 labels, test_size = 0.2, random_state = random.randint(0, 100)) 
  
 #print(train_images.shape[0], 'train samples')
 #print(valid_images.shape[0], 'valid samples') 
 return train_images, valid_images, train_labels ,valid_labels

3、分割后的数据分别添加到总的训练集,验证集,训练标签,验证标签。

其次,清空原有的图片集和标签集,目的是节省内存。假如一次性读入多个标签的数据集与标签集,进行数据分割后,会占用大于单纯进行上述操作两倍以上的内存。

images = np.array(images) 
t_images, v_images, t_labels ,v_labels = split_dataset(images, labels) 
for i in range(len(t_images)):
 train_images.append(t_images[i])
 train_labels.append(t_labels[i]) 
for j in range(len(v_images)):
 valid_images.append(v_images[j])
 valid_labels.append(v_labels[j])
if counter%50== 49:
 print( counter+1 , "is read to the memory!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
 
images = []
labels = [] 
counter = counter + 1 

print("train_images num: ", len(train_images), " ", "valid_images num: ",len(valid_images))

4、进行判断,直到读到自己自己分割的那个标签。

开始进行写入。写入之前,为了更好地训练模型,需要把对应的图片集和标签打乱顺序。

if ((counter % 4316 == 4315) or (counter == nb_classes - 1)): 
  
  print("start write images and labels data...................................................................")  
  num = counter // 5000
  dirs = data_path + "/" + "h5_" + str(num - 1)
  if not os.path.exists(dirs):
  os.makedirs(dirs)
  data2h5(dirs, t_images, v_images, t_labels ,v_labels)

对应打乱顺序并写入到HDF5

def data2h5(dirs_path, train_images, valid_images, train_labels ,valid_labels):
 
 TRAIN_HDF5 = dirs_path + '/' + "train.hdf5"
 VAL_HDF5 = dirs_path + '/' + "val.hdf5"
 
 #shuffle
 state1 = np.random.get_state()
 np.random.shuffle(train_images)
 np.random.set_state(state1)
 np.random.shuffle(train_labels)
 
 state2 = np.random.get_state()
 np.random.shuffle(valid_images)
 np.random.set_state(state2)
 np.random.shuffle(valid_labels)
 
 datasets = [
 ("train",train_images,train_labels,TRAIN_HDF5),
 ("val",valid_images,valid_labels,VAL_HDF5)]
 
 for (dType,images,labels,outputPath) in datasets:
 # HDF5 initial
 f = h5py.File(outputPath, "w")
 f.create_dataset("x_"+dType, data=images)
 f.create_dataset("y_"+dType, data=labels)
 #f.create_dataset("x_"+dType, data=images, compression="gzip", compression_opts=9)
 #f.create_dataset("y_"+dType, data=labels, compression="gzip", compression_opts=9)
 f.close()

5、判断文件全部读入

def read_dataset(dirs):
 
 files = os.listdir(dirs)
 print(files)
 for file in files:
 path = dirs+'/' + file
 dataset = h5py.File(path, "r")
 file = file.split('.')
 set_x_orig = dataset["x_"+file[0]].shape[0]
 set_y_orig = dataset["y_"+file[0]].shape[0]

 print(set_x_orig)
 print(set_y_orig)

6、训练中,采用迭代器读入数据

def generator(self, datagen, mode):
 
 passes=np.inf
 aug = ImageDataGenerator(
  featurewise_center = False,  
  samplewise_center = False,  
  featurewise_std_normalization = False, 
  samplewise_std_normalization = False, 
  zca_whitening = False,   
  rotation_range = 20,   
  width_shift_range = 0.2,  
  height_shift_range = 0.2,  
  horizontal_flip = True,  
  vertical_flip = False)  
 
 epochs = 0  
 # 默认是无限循环遍历
 
 while epochs < passes:
  # 遍历数据
  file_dir = os.listdir(self.data_path)
  for file in file_dir:
  #print(file)
  file_path = os.path.join(self.data_path,file)
  TRAIN_HDF5 = file_path +"/train.hdf5"
  VAL_HDF5 = file_path +"/val.hdf5"
  #TEST_HDF5 = file_path +"/test.hdf5"
  
  db_t = h5py.File(TRAIN_HDF5)
  numImages_t = db_t['y_train'].shape[0] 
  db_v = h5py.File(VAL_HDF5)
  numImages_v = db_v['y_val'].shape[0] 
  
  if mode == "train":  
   for i in np.arange(0, numImages_t, self.BS):
   
   images = db_t['x_train'][i: i+self.BS]
   labels = db_t['y_train'][i: i+self.BS]
   
   if K.image_data_format() == 'channels_first':
   
    images = images.reshape(images.shape[0], 3, IMAGE_SIZE,IMAGE_SIZE) 
   else:
    images = images.reshape(images.shape[0], IMAGE_SIZE, IMAGE_SIZE, 3) 
   
   images = images.astype('float32')
   images = images/255   
      
   if datagen :
    (images,labels) = next(aug.flow(images,labels,batch_size = self.BS))   
      
   # one-hot编码
   if self.binarize:
    labels = np_utils.to_categorical(labels,self.classes)   
   
   yield ({'input_1': images}, {'softmax': labels})
    
  elif mode == "val":
   for i in np.arange(0, numImages_v, self.BS):
   images = db_v['x_val'][i: i+self.BS]
   labels = db_v['y_val'][i: i+self.BS] 
   
   if K.image_data_format() == 'channels_first':
   
    images = images.reshape(images.shape[0], 3, IMAGE_SIZE,IMAGE_SIZE) 
   else:
    images = images.reshape(images.shape[0], IMAGE_SIZE, IMAGE_SIZE, 3) 
   
   images = images.astype('float32')
   images = images/255   
   
   if datagen :
    (images,labels) = next(aug.flow(images,labels,batch_size = self.BS))   

   #one-hot编码
   if self.binarize:
    labels = np_utils.to_categorical(labels,self.classes) 
    
   yield ({'input_1': images}, {'softmax': labels})
     
  epochs += 1

7、至此,就大功告成了

完整的代码:

# -*- coding: utf-8 -*-
"""
Created on Mon Feb 12 20:46:12 2018

@author: william_yue
"""
import os
import numpy as np
import cv2
import random
from scipy import misc
import h5py
from sklearn.model_selection import train_test_split
from keras import backend as K
K.clear_session()
from keras.utils import np_utils

IMAGE_SIZE = 128
 
# 加载数据集并按照交叉验证的原则划分数据集并进行相关预处理工作
def split_dataset(images, labels): 
 # 导入了sklearn库的交叉验证模块,利用函数train_test_split()来划分训练集和验证集
 # 划分出了20%的数据用于验证,80%用于训练模型
 train_images, valid_images, train_labels, valid_labels = train_test_split(images,\
 labels, test_size = 0.2, random_state = random.randint(0, 100)) 
 return train_images, valid_images, train_labels ,valid_labels
 
def data2h5(dirs_path, train_images, valid_images, train_labels ,valid_labels):
 
#def data2h5(dirs_path, train_images, valid_images, test_images, train_labels ,valid_labels, test_labels):
 
 TRAIN_HDF5 = dirs_path + '/' + "train.hdf5"
 VAL_HDF5 = dirs_path + '/' + "val.hdf5"
 
 #采用标签与图片相同的顺序分别打乱训练集与验证集
 state1 = np.random.get_state()
 np.random.shuffle(train_images)
 np.random.set_state(state1)
 np.random.shuffle(train_labels)
 
 state2 = np.random.get_state()
 np.random.shuffle(valid_images)
 np.random.set_state(state2)
 np.random.shuffle(valid_labels)
 
 datasets = [
 ("train",train_images,train_labels,TRAIN_HDF5),
 ("val",valid_images,valid_labels,VAL_HDF5)]
 
 for (dType,images,labels,outputPath) in datasets:
 # 初始化HDF5写入
 f = h5py.File(outputPath, "w")
 f.create_dataset("x_"+dType, data=images)
 f.create_dataset("y_"+dType, data=labels)
 #f.create_dataset("x_"+dType, data=images, compression="gzip", compression_opts=9)
 #f.create_dataset("y_"+dType, data=labels, compression="gzip", compression_opts=9)
 f.close()

def read_dataset(dirs):
 files = os.listdir(dirs)
 print(files)
 for file in files:
 path = dirs+'/' + file 
 file_read = os.listdir(path)
 for i in file_read:
  path_read = os.path.join(path, i)
  dataset = h5py.File(path_read, "r")
  i = i.split('.')
  set_x_orig = dataset["x_"+i[0]].shape[0]
  set_y_orig = dataset["y_"+i[0]].shape[0]
  print(set_x_orig)
  print(set_y_orig)

#循环读取每个标签集下的所有图片
def load_dataset(path_name,data_path):
 images = []
 labels = []
 train_images = []
 valid_images = []
 train_labels = []
 valid_labels = []
 counter = 0
 allpath = os.listdir(path_name)
 nb_classes = len(allpath)
 print("label_num: ",nb_classes)
 
 for child_dir in allpath:
 child_path = os.path.join(path_name, child_dir)
 for dir_image in os.listdir(child_path):
  if dir_image.endswith('.jpg'):
  img = cv2.imread(os.path.join(child_path, dir_image))  
  image = misc.imresize(img, (IMAGE_SIZE, IMAGE_SIZE), interp='bilinear')
  #resized_img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
  images.append(image)
  labels.append(counter)
   
 images = np.array(images) 
 t_images, v_images, t_labels ,v_labels = split_dataset(images, labels) 
 for i in range(len(t_images)):
  train_images.append(t_images[i])
  train_labels.append(t_labels[i]) 
 for j in range(len(v_images)):
  valid_images.append(v_images[j])
  valid_labels.append(v_labels[j])
 if counter%50== 49:
  print( counter+1 , "is read to the memory!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
  
 images = []
 labels = [] 
 
 if ((counter % 4316 == 4315) or (counter == nb_classes - 1)): 
  print("train_images num: ", len(train_images), "  ", "valid_images num: ",len(valid_images)) 
  print("start write images and labels data...................................................................")  
  num = counter // 5000
  dirs = data_path + "/" + "h5_" + str(num - 1)
  if not os.path.exists(dirs):
  os.makedirs(dirs)
  data2h5(dirs, train_images, valid_images, train_labels ,valid_labels)
  #read_dataset(dirs)
  print("File HDF5_%d "%num, " id done!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
  train_images = []
  valid_images = []
  train_labels = []
  valid_labels = [] 
 counter = counter + 1 
 print("All File HDF5 done!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
 read_dataset(data_path) 

#读取训练数据集的文件夹,把他们的名字返回给一个list
def read_name_list(path_name):
 name_list = []
 for child_dir in os.listdir(path_name):
 name_list.append(child_dir)
 return name_list

if __name__ == '__main__':
 path = "data"
 data_path = "data_hdf5_half"
 if not os.path.exists(data_path):
 os.makedirs(data_path)
 load_dataset(path,data_path)

以上这篇完美解决keras 读取多个hdf5文件进行训练的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python cookbook(数据结构与算法)从任意长度的可迭代对象中分解元素操作示例
Feb 13 Python
python将秒数转化为时间格式的实例
Sep 16 Python
Django框架中间件(Middleware)用法实例分析
May 24 Python
selenium2.0中常用的python函数汇总
Aug 05 Python
python笔记_将循环内容在一行输出的方法
Aug 08 Python
使用python实现unix2dos和dos2unix命令的例子
Aug 13 Python
DjangoWeb使用Datatable进行后端分页的实现
May 18 Python
Python Django中间件使用原理及流程分析
Jun 13 Python
Python Tkinter实例——模拟掷骰子
Oct 24 Python
Python 按比例获取样本数据或执行任务的实现代码
Dec 03 Python
如何在vscode中安装python库的方法步骤
Jan 06 Python
Python - 10行代码集2000张美女图
May 23 Python
学python需要去培训机构吗
Jul 01 #Python
详解python logging日志传输
Jul 01 #Python
python怎么调用自己的函数
Jul 01 #Python
解决keras模型保存h5文件提示无此目录问题
Jul 01 #Python
如何解决安装python3.6.1失败
Jul 01 #Python
python如何求圆的面积
Jul 01 #Python
python怎么判断素数
Jul 01 #Python
You might like
详解Laravel5.6 Passport实现Api接口认证
2018/07/27 PHP
newxtree.js代码
2007/03/13 Javascript
JavaScript高级程序设计(第3版)学习笔记10 再访js对象
2012/10/11 Javascript
表格奇偶行设置不同颜色的核心JS代码
2013/12/24 Javascript
Get中文乱码IE浏览器Get中文乱码解决方案
2013/12/26 Javascript
js中document.write的那点事
2014/12/12 Javascript
javascript实现密码强度显示
2015/03/18 Javascript
使用AngularJS编写较为优美的JavaScript代码指南
2015/06/19 Javascript
jQuery+PHP实现可编辑表格字段内容并实时保存
2015/10/09 Javascript
JavaScript自定义浏览器滚动条兼容IE、 火狐和chrome
2017/01/05 Javascript
JS验证不重复验证码
2017/02/10 Javascript
基于Vue.js实现tab滑块效果
2017/07/23 Javascript
JS简单生成由字母数字组合随机字符串示例
2018/05/25 Javascript
详解vue中localStorage的使用方法
2018/11/22 Javascript
layui使用及简单的三级联动实现教程
2020/12/01 Javascript
Python执行时间的计算方法小结
2017/03/17 Python
django开发之settings.py中变量的全局引用详解
2017/03/29 Python
python 队列基本定义与使用方法【初始化、赋值、判断等】
2019/10/24 Python
对python中return与yield的区别详解
2020/03/12 Python
Python基于read(size)方法读取超大文件
2020/03/12 Python
matplotlib常见函数之plt.rcParams、matshow的使用(坐标轴设置)
2021/01/05 Python
html5 视频播放解决方案
2016/11/06 HTML / CSS
英国最大的奢侈品零售网络商城:Flannels
2016/09/16 全球购物
Net-A-Porter美国官网:全球时尚奢侈品名站
2017/02/11 全球购物
Waterford美国官网:爱尔兰水晶制品品牌
2017/04/26 全球购物
The North Face北面美国官网:美国著名户外品牌
2018/09/15 全球购物
什么是Smarty变量操作符?如何使用Smarty变量操作符
2014/07/18 面试题
主治医师岗位职责
2013/12/10 职场文书
社区党的群众路线教育实践活动剖析材料
2014/10/09 职场文书
乡镇安全生产月活动总结
2015/05/08 职场文书
2015仓库保管员年终工作总结
2015/05/13 职场文书
海上钢琴师的观后感
2015/06/11 职场文书
新学期开学寄语2016
2015/12/04 职场文书
导游词之麻姑仙境
2019/11/18 职场文书
用Python提取PDF表格的方法
2021/04/11 Python
Redis Cluster 字段模糊匹配及删除
2021/05/27 Redis