完美解决keras 读取多个hdf5文件进行训练的问题


Posted in Python onJuly 01, 2020

用keras进行大数据训练,为了加快训练,需要提前制作训练集。

由于HDF5的特性,所有数据需要一次性读入到内存中,才能保存。

为此,我采用分批次分为2个以上HDF5进行存储。

1、先读取每个标签下的图片,并设置标签

def load_dataset(path_name,data_path):
 images = []
 labels = []
 train_images = []
 valid_images = [] 
 train_labels = []
 valid_labels = []
 counter = 0
 allpath = os.listdir(path_name)
 nb_classes = len(allpath)
 print("label_num: ",nb_classes)
 
 for child_dir in allpath:
 child_path = os.path.join(path_name, child_dir)
 for dir_image in os.listdir(child_path):
  if dir_image.endswith('.jpg'):
  img = cv2.imread(os.path.join(child_path, dir_image))  
  image = misc.imresize(img, (IMAGE_SIZE, IMAGE_SIZE), interp='bilinear')
  #resized_img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
  images.append(image)
  labels.append(counter)

2、该标签下的数据集分割为训练集(train images),验证集(val images),训练标签(train labels),验证标签

(val labels)

def split_dataset(images, labels): 

 train_images, valid_images, train_labels, valid_labels = train_test_split(images,\
 labels, test_size = 0.2, random_state = random.randint(0, 100)) 
  
 #print(train_images.shape[0], 'train samples')
 #print(valid_images.shape[0], 'valid samples') 
 return train_images, valid_images, train_labels ,valid_labels

3、分割后的数据分别添加到总的训练集,验证集,训练标签,验证标签。

其次,清空原有的图片集和标签集,目的是节省内存。假如一次性读入多个标签的数据集与标签集,进行数据分割后,会占用大于单纯进行上述操作两倍以上的内存。

images = np.array(images) 
t_images, v_images, t_labels ,v_labels = split_dataset(images, labels) 
for i in range(len(t_images)):
 train_images.append(t_images[i])
 train_labels.append(t_labels[i]) 
for j in range(len(v_images)):
 valid_images.append(v_images[j])
 valid_labels.append(v_labels[j])
if counter%50== 49:
 print( counter+1 , "is read to the memory!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
 
images = []
labels = [] 
counter = counter + 1 

print("train_images num: ", len(train_images), " ", "valid_images num: ",len(valid_images))

4、进行判断,直到读到自己自己分割的那个标签。

开始进行写入。写入之前,为了更好地训练模型,需要把对应的图片集和标签打乱顺序。

if ((counter % 4316 == 4315) or (counter == nb_classes - 1)): 
  
  print("start write images and labels data...................................................................")  
  num = counter // 5000
  dirs = data_path + "/" + "h5_" + str(num - 1)
  if not os.path.exists(dirs):
  os.makedirs(dirs)
  data2h5(dirs, t_images, v_images, t_labels ,v_labels)

对应打乱顺序并写入到HDF5

def data2h5(dirs_path, train_images, valid_images, train_labels ,valid_labels):
 
 TRAIN_HDF5 = dirs_path + '/' + "train.hdf5"
 VAL_HDF5 = dirs_path + '/' + "val.hdf5"
 
 #shuffle
 state1 = np.random.get_state()
 np.random.shuffle(train_images)
 np.random.set_state(state1)
 np.random.shuffle(train_labels)
 
 state2 = np.random.get_state()
 np.random.shuffle(valid_images)
 np.random.set_state(state2)
 np.random.shuffle(valid_labels)
 
 datasets = [
 ("train",train_images,train_labels,TRAIN_HDF5),
 ("val",valid_images,valid_labels,VAL_HDF5)]
 
 for (dType,images,labels,outputPath) in datasets:
 # HDF5 initial
 f = h5py.File(outputPath, "w")
 f.create_dataset("x_"+dType, data=images)
 f.create_dataset("y_"+dType, data=labels)
 #f.create_dataset("x_"+dType, data=images, compression="gzip", compression_opts=9)
 #f.create_dataset("y_"+dType, data=labels, compression="gzip", compression_opts=9)
 f.close()

5、判断文件全部读入

def read_dataset(dirs):
 
 files = os.listdir(dirs)
 print(files)
 for file in files:
 path = dirs+'/' + file
 dataset = h5py.File(path, "r")
 file = file.split('.')
 set_x_orig = dataset["x_"+file[0]].shape[0]
 set_y_orig = dataset["y_"+file[0]].shape[0]

 print(set_x_orig)
 print(set_y_orig)

6、训练中,采用迭代器读入数据

def generator(self, datagen, mode):
 
 passes=np.inf
 aug = ImageDataGenerator(
  featurewise_center = False,  
  samplewise_center = False,  
  featurewise_std_normalization = False, 
  samplewise_std_normalization = False, 
  zca_whitening = False,   
  rotation_range = 20,   
  width_shift_range = 0.2,  
  height_shift_range = 0.2,  
  horizontal_flip = True,  
  vertical_flip = False)  
 
 epochs = 0  
 # 默认是无限循环遍历
 
 while epochs < passes:
  # 遍历数据
  file_dir = os.listdir(self.data_path)
  for file in file_dir:
  #print(file)
  file_path = os.path.join(self.data_path,file)
  TRAIN_HDF5 = file_path +"/train.hdf5"
  VAL_HDF5 = file_path +"/val.hdf5"
  #TEST_HDF5 = file_path +"/test.hdf5"
  
  db_t = h5py.File(TRAIN_HDF5)
  numImages_t = db_t['y_train'].shape[0] 
  db_v = h5py.File(VAL_HDF5)
  numImages_v = db_v['y_val'].shape[0] 
  
  if mode == "train":  
   for i in np.arange(0, numImages_t, self.BS):
   
   images = db_t['x_train'][i: i+self.BS]
   labels = db_t['y_train'][i: i+self.BS]
   
   if K.image_data_format() == 'channels_first':
   
    images = images.reshape(images.shape[0], 3, IMAGE_SIZE,IMAGE_SIZE) 
   else:
    images = images.reshape(images.shape[0], IMAGE_SIZE, IMAGE_SIZE, 3) 
   
   images = images.astype('float32')
   images = images/255   
      
   if datagen :
    (images,labels) = next(aug.flow(images,labels,batch_size = self.BS))   
      
   # one-hot编码
   if self.binarize:
    labels = np_utils.to_categorical(labels,self.classes)   
   
   yield ({'input_1': images}, {'softmax': labels})
    
  elif mode == "val":
   for i in np.arange(0, numImages_v, self.BS):
   images = db_v['x_val'][i: i+self.BS]
   labels = db_v['y_val'][i: i+self.BS] 
   
   if K.image_data_format() == 'channels_first':
   
    images = images.reshape(images.shape[0], 3, IMAGE_SIZE,IMAGE_SIZE) 
   else:
    images = images.reshape(images.shape[0], IMAGE_SIZE, IMAGE_SIZE, 3) 
   
   images = images.astype('float32')
   images = images/255   
   
   if datagen :
    (images,labels) = next(aug.flow(images,labels,batch_size = self.BS))   

   #one-hot编码
   if self.binarize:
    labels = np_utils.to_categorical(labels,self.classes) 
    
   yield ({'input_1': images}, {'softmax': labels})
     
  epochs += 1

7、至此,就大功告成了

完整的代码:

# -*- coding: utf-8 -*-
"""
Created on Mon Feb 12 20:46:12 2018

@author: william_yue
"""
import os
import numpy as np
import cv2
import random
from scipy import misc
import h5py
from sklearn.model_selection import train_test_split
from keras import backend as K
K.clear_session()
from keras.utils import np_utils

IMAGE_SIZE = 128
 
# 加载数据集并按照交叉验证的原则划分数据集并进行相关预处理工作
def split_dataset(images, labels): 
 # 导入了sklearn库的交叉验证模块,利用函数train_test_split()来划分训练集和验证集
 # 划分出了20%的数据用于验证,80%用于训练模型
 train_images, valid_images, train_labels, valid_labels = train_test_split(images,\
 labels, test_size = 0.2, random_state = random.randint(0, 100)) 
 return train_images, valid_images, train_labels ,valid_labels
 
def data2h5(dirs_path, train_images, valid_images, train_labels ,valid_labels):
 
#def data2h5(dirs_path, train_images, valid_images, test_images, train_labels ,valid_labels, test_labels):
 
 TRAIN_HDF5 = dirs_path + '/' + "train.hdf5"
 VAL_HDF5 = dirs_path + '/' + "val.hdf5"
 
 #采用标签与图片相同的顺序分别打乱训练集与验证集
 state1 = np.random.get_state()
 np.random.shuffle(train_images)
 np.random.set_state(state1)
 np.random.shuffle(train_labels)
 
 state2 = np.random.get_state()
 np.random.shuffle(valid_images)
 np.random.set_state(state2)
 np.random.shuffle(valid_labels)
 
 datasets = [
 ("train",train_images,train_labels,TRAIN_HDF5),
 ("val",valid_images,valid_labels,VAL_HDF5)]
 
 for (dType,images,labels,outputPath) in datasets:
 # 初始化HDF5写入
 f = h5py.File(outputPath, "w")
 f.create_dataset("x_"+dType, data=images)
 f.create_dataset("y_"+dType, data=labels)
 #f.create_dataset("x_"+dType, data=images, compression="gzip", compression_opts=9)
 #f.create_dataset("y_"+dType, data=labels, compression="gzip", compression_opts=9)
 f.close()

def read_dataset(dirs):
 files = os.listdir(dirs)
 print(files)
 for file in files:
 path = dirs+'/' + file 
 file_read = os.listdir(path)
 for i in file_read:
  path_read = os.path.join(path, i)
  dataset = h5py.File(path_read, "r")
  i = i.split('.')
  set_x_orig = dataset["x_"+i[0]].shape[0]
  set_y_orig = dataset["y_"+i[0]].shape[0]
  print(set_x_orig)
  print(set_y_orig)

#循环读取每个标签集下的所有图片
def load_dataset(path_name,data_path):
 images = []
 labels = []
 train_images = []
 valid_images = []
 train_labels = []
 valid_labels = []
 counter = 0
 allpath = os.listdir(path_name)
 nb_classes = len(allpath)
 print("label_num: ",nb_classes)
 
 for child_dir in allpath:
 child_path = os.path.join(path_name, child_dir)
 for dir_image in os.listdir(child_path):
  if dir_image.endswith('.jpg'):
  img = cv2.imread(os.path.join(child_path, dir_image))  
  image = misc.imresize(img, (IMAGE_SIZE, IMAGE_SIZE), interp='bilinear')
  #resized_img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
  images.append(image)
  labels.append(counter)
   
 images = np.array(images) 
 t_images, v_images, t_labels ,v_labels = split_dataset(images, labels) 
 for i in range(len(t_images)):
  train_images.append(t_images[i])
  train_labels.append(t_labels[i]) 
 for j in range(len(v_images)):
  valid_images.append(v_images[j])
  valid_labels.append(v_labels[j])
 if counter%50== 49:
  print( counter+1 , "is read to the memory!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
  
 images = []
 labels = [] 
 
 if ((counter % 4316 == 4315) or (counter == nb_classes - 1)): 
  print("train_images num: ", len(train_images), "  ", "valid_images num: ",len(valid_images)) 
  print("start write images and labels data...................................................................")  
  num = counter // 5000
  dirs = data_path + "/" + "h5_" + str(num - 1)
  if not os.path.exists(dirs):
  os.makedirs(dirs)
  data2h5(dirs, train_images, valid_images, train_labels ,valid_labels)
  #read_dataset(dirs)
  print("File HDF5_%d "%num, " id done!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
  train_images = []
  valid_images = []
  train_labels = []
  valid_labels = [] 
 counter = counter + 1 
 print("All File HDF5 done!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
 read_dataset(data_path) 

#读取训练数据集的文件夹,把他们的名字返回给一个list
def read_name_list(path_name):
 name_list = []
 for child_dir in os.listdir(path_name):
 name_list.append(child_dir)
 return name_list

if __name__ == '__main__':
 path = "data"
 data_path = "data_hdf5_half"
 if not os.path.exists(data_path):
 os.makedirs(data_path)
 load_dataset(path,data_path)

以上这篇完美解决keras 读取多个hdf5文件进行训练的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python遍历 truple list dictionary的几种方法总结
Sep 11 Python
1分钟快速生成用于网页内容提取的xslt
Feb 23 Python
python针对excel的操作技巧
Mar 13 Python
Python读取Word(.docx)正文信息的方法
Mar 15 Python
python os.path模块常用方法实例详解
Sep 16 Python
学习和使用python的13个理由
Jul 30 Python
python字符串格式化方式解析
Oct 19 Python
Python 中的 import 机制之实现远程导入模块
Oct 29 Python
python基于opencv检测程序运行效率
Dec 28 Python
python 爬取疫情数据的源码
Feb 09 Python
Python获取二维数组的行列数的2种方法
Feb 11 Python
python数据分析之单因素分析线性拟合及地理编码
Jun 25 Python
学python需要去培训机构吗
Jul 01 #Python
详解python logging日志传输
Jul 01 #Python
python怎么调用自己的函数
Jul 01 #Python
解决keras模型保存h5文件提示无此目录问题
Jul 01 #Python
如何解决安装python3.6.1失败
Jul 01 #Python
python如何求圆的面积
Jul 01 #Python
python怎么判断素数
Jul 01 #Python
You might like
探讨:如何编写PHP扩展
2013/06/13 PHP
php图像处理函数大全(推荐收藏)
2013/07/11 PHP
php中Snoopy类用法实例
2015/06/19 PHP
php无限级分类实现评论及回复功能
2019/02/18 PHP
laravel在中间件内生成参数并且传递到控制器中的2种姿势
2019/10/15 PHP
PHP实现微信提现功能(微信商城)
2019/11/21 PHP
检测是否已安装 .NET Framework 3.5的js脚本
2009/02/14 Javascript
可以用来调试JavaScript错误的解决方案
2010/08/07 Javascript
Extjs中通过Tree加载右侧TabPanel具体实现
2013/05/05 Javascript
输入自动提示搜索提示功能的使用说明:sugggestion.txt
2013/09/02 Javascript
javascript如何操作HTML下拉列表标签
2015/08/20 Javascript
javascript实现查找数组中最大值方法汇总
2016/02/13 Javascript
省市区三级联动jquery实现代码
2020/04/15 Javascript
node文件批量重命名的方法示例
2017/10/23 Javascript
如何阻止小程序遮罩层下方图层滚动
2019/09/05 Javascript
layui select 禁止点击的实现方法
2019/09/05 Javascript
viewer.js一个强大的基于jQuery的图像查看插件(支持旋转、缩放)
2020/04/01 jQuery
[11:44]Ti9 OG夺冠时刻
2019/08/25 DOTA
深入Python函数编程的一些特性
2015/04/13 Python
Python实现多线程抓取妹子图
2015/08/08 Python
详解Ubuntu16.04安装Python3.7及其pip3并切换为默认版本
2019/02/25 Python
python分布式计算dispy的使用详解
2019/12/22 Python
Python常用编译器原理及特点解析
2020/03/23 Python
python实现图像高斯金字塔的示例代码
2020/12/11 Python
Orvis官网:自1856年以来,优质服装、飞钓装备等
2018/12/17 全球购物
异常和异常类的概念
2014/09/12 面试题
物流专业求职计划书
2014/01/10 职场文书
区优秀教师事迹材料
2014/02/10 职场文书
艺术教育实施方案
2014/05/03 职场文书
团员个人年度总结
2015/02/26 职场文书
银行实习推荐信
2015/03/27 职场文书
2016年“5.12”护士节慰问信
2015/11/30 职场文书
导游词之安徽九华山
2019/09/18 职场文书
Nginx+Tomcat实现负载均衡、动静分离的原理解析
2021/03/31 Servers
SQL实战演练之网上商城数据库商品类别数据操作
2021/10/24 MySQL
如何让你的Nginx支持分布式追踪详解
2022/07/07 Servers