keras使用Sequence类调用大规模数据集进行训练的实现


Posted in Python onJune 22, 2020

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码

class SequenceData(Sequence):
  def __init__(self, path, batch_size=32):
    self.path = path
    self.batch_size = batch_size
    f = open(path)
    self.datas = f.readlines()
    self.L = len(self.datas)
    self.index = random.sample(range(self.L), self.L)
  #返回长度,通过len(<你的实例>)调用
  def __len__(self):
    return self.L - self.batch_size
  #即通过索引获取a[0],a[1]这种
  def __getitem__(self, idx):
    batch_indexs = self.index[idx:(idx+self.batch_size)]
    batch_datas = [self.datas[k] for k in batch_indexs]
    img1s,img2s,audios,labels = self.data_generation(batch_datas)
    return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

  def data_generation(self, batch_datas):
    #预处理操作
    return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), 
          epochs=2, workers=20, #callbacks=[checkpoint],
          use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, datas, batch_size=1, shuffle=True):
    self.batch_size = batch_size
    self.datas = datas
    self.indexes = np.arange(len(self.datas))
    self.shuffle = shuffle

  def __len__(self):
    #计算每一个epoch的迭代次数
    return math.ceil(len(self.datas) / float(self.batch_size))

  def __getitem__(self, index):
    #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
    # 生成batch_size个索引
    batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    # 根据索引获取datas集合中的数据
    batch_datas = [self.datas[k] for k in batch_indexs]

    # 生成数据
    X, y = self.data_generation(batch_datas)

    return X, y

  def on_epoch_end(self):
    #在每一次epoch结束是否需要进行一次随机,重新随机一下index
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def data_generation(self, batch_datas):
    images = []
    labels = []

    # 生成数据
    for i, data in enumerate(batch_datas):
      #x_train数据
      image = cv2.imread(data)
      image = list(image)
      images.append(image)
      #y_train数据 
      right = data.rfind("\\",0)
      left = data.rfind("\\",0,right)+1
      class_name = data[left:right]
      if class_name=="dog":
        labels.append([0,1])
      else: 
        labels.append([1,0])
    #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
    return np.array(images), np.array(labels)
  
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
  file_path = os.path.join("D:/xxx", file)
  if os.path.isdir(file_path):
    class_num = class_num + 1
    for sub_file in os.listdir(file_path):
      train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的Numeric包和Numarray包使用教程
Apr 13 Python
Python3.x版本中新的字符串格式化方法
Apr 24 Python
python开发之文件操作用法实例
Nov 13 Python
python生成二维码的实例详解
Oct 29 Python
如何使用 Pylint 来规范 Python 代码风格(来自IBM)
Apr 06 Python
python中使用psutil查看内存占用的情况
Jun 11 Python
Python Django框架单元测试之文件上传测试示例
May 17 Python
python3获取当前目录的实现方法
Jul 29 Python
python 三元运算符使用解析
Sep 16 Python
Python实现大数据收集至excel的思路详解
Jan 03 Python
如何定义TensorFlow输入节点
Jan 23 Python
python实现吃苹果小游戏
Mar 21 Python
Python socket服务常用操作代码实例
Jun 22 #Python
Python如何实现后端自定义认证并实现多条件登陆
Jun 22 #Python
零基础小白多久能学会python
Jun 22 #Python
Keras-多输入多输出实例(多任务)
Jun 22 #Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
You might like
PHP的范围解析操作符(::)的含义分析说明
2011/07/03 PHP
php禁用函数设置及查看方法详解
2016/07/25 PHP
PHP关于foreach复制知识点总结
2019/01/28 PHP
基于jquery实现拆分姓名的方法(纯JS版)
2013/05/08 Javascript
jquery 图片缩放拖动的简单实例
2014/01/08 Javascript
JQuery中$.each 和$(selector).each()的区别详解
2015/03/13 Javascript
js实现Select列表内容自动滚动效果代码
2015/08/20 Javascript
JS获取input file绝对路径的方法(推荐)
2016/08/02 Javascript
vue-router路由与页面间导航实例解析
2017/11/07 Javascript
微信小程序实现图片上传、删除和预览功能的方法
2017/12/18 Javascript
在小程序开发中使用npm的方法
2018/10/17 Javascript
JQueryDOM之样式操作
2019/03/27 jQuery
JS中的防抖与节流及作用详解
2019/04/01 Javascript
js JSON.stringify()基础详解
2019/06/19 Javascript
Vue 的 v-model用法实例
2020/11/23 Vue.js
[02:07]DOTA2新英雄展现中国元素,完美“圣典”亮相央视
2016/12/19 DOTA
[01:05:00]2018国际邀请赛 表演赛 Pain vs OpenAI
2018/08/24 DOTA
Python实现抓取网页并且解析的实例
2014/09/20 Python
Python3字符串学习教程
2015/08/20 Python
python从入门到精通(DAY 1)
2015/12/20 Python
unittest+coverage单元测试代码覆盖操作实例详解
2018/04/04 Python
详解python播放音频的三种方法
2019/09/23 Python
python isinstance函数用法详解
2020/02/13 Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
2020/02/29 Python
pyautogui自动化控制鼠标和键盘操作的步骤
2020/04/01 Python
Python内置函数locals和globals对比
2020/04/28 Python
python3通过qq邮箱发送邮件以及附件
2020/05/20 Python
localStorage的过期时间设置的方法详解
2018/11/26 HTML / CSS
Tarte Cosmetics官网:美国最受欢迎的化妆品公司之一
2017/08/24 全球购物
阿迪达斯法国官方网站:adidas法国
2018/03/20 全球购物
初级软件工程师面试题 Junior Software Engineer Interview
2015/02/15 面试题
在教室放鞭炮的检讨书
2014/09/28 职场文书
2014年销售工作总结与计划
2014/12/01 职场文书
MySQL基础(二)
2021/04/05 MySQL
Pillow图像处理库安装及使用
2022/04/12 Python
3050和2060哪个好 性能差多少 差距有多大 谁更有性价比
2022/06/17 数码科技