keras使用Sequence类调用大规模数据集进行训练的实现


Posted in Python onJune 22, 2020

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码

class SequenceData(Sequence):
  def __init__(self, path, batch_size=32):
    self.path = path
    self.batch_size = batch_size
    f = open(path)
    self.datas = f.readlines()
    self.L = len(self.datas)
    self.index = random.sample(range(self.L), self.L)
  #返回长度,通过len(<你的实例>)调用
  def __len__(self):
    return self.L - self.batch_size
  #即通过索引获取a[0],a[1]这种
  def __getitem__(self, idx):
    batch_indexs = self.index[idx:(idx+self.batch_size)]
    batch_datas = [self.datas[k] for k in batch_indexs]
    img1s,img2s,audios,labels = self.data_generation(batch_datas)
    return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

  def data_generation(self, batch_datas):
    #预处理操作
    return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), 
          epochs=2, workers=20, #callbacks=[checkpoint],
          use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, datas, batch_size=1, shuffle=True):
    self.batch_size = batch_size
    self.datas = datas
    self.indexes = np.arange(len(self.datas))
    self.shuffle = shuffle

  def __len__(self):
    #计算每一个epoch的迭代次数
    return math.ceil(len(self.datas) / float(self.batch_size))

  def __getitem__(self, index):
    #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
    # 生成batch_size个索引
    batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    # 根据索引获取datas集合中的数据
    batch_datas = [self.datas[k] for k in batch_indexs]

    # 生成数据
    X, y = self.data_generation(batch_datas)

    return X, y

  def on_epoch_end(self):
    #在每一次epoch结束是否需要进行一次随机,重新随机一下index
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def data_generation(self, batch_datas):
    images = []
    labels = []

    # 生成数据
    for i, data in enumerate(batch_datas):
      #x_train数据
      image = cv2.imread(data)
      image = list(image)
      images.append(image)
      #y_train数据 
      right = data.rfind("\\",0)
      left = data.rfind("\\",0,right)+1
      class_name = data[left:right]
      if class_name=="dog":
        labels.append([0,1])
      else: 
        labels.append([1,0])
    #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
    return np.array(images), np.array(labels)
  
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
  file_path = os.path.join("D:/xxx", file)
  if os.path.isdir(file_path):
    class_num = class_num + 1
    for sub_file in os.listdir(file_path):
      train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量生成本地ip地址的方法
Mar 23 Python
Python中functools模块的常用函数解析
Jun 30 Python
利用python实现命令行有道词典的方法示例
Jan 31 Python
Python3.x对JSON的一些操作示例
Sep 01 Python
使用Python来开发微信功能
Jun 13 Python
Django 接收Post请求数据,并保存到数据库的实现方法
Jul 12 Python
windows下python虚拟环境virtualenv安装和使用详解
Jul 16 Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 Python
基于SQLAlchemy实现操作MySQL并执行原生sql语句
Jun 10 Python
如何利用python读取micaps文件详解
Oct 18 Python
python通用数据库操作工具 pydbclib的使用简介
Dec 21 Python
Python selenium模拟网页点击爬虫交管12123违章数据
May 26 Python
Python socket服务常用操作代码实例
Jun 22 #Python
Python如何实现后端自定义认证并实现多条件登陆
Jun 22 #Python
零基础小白多久能学会python
Jun 22 #Python
Keras-多输入多输出实例(多任务)
Jun 22 #Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
You might like
最贵的咖啡是怎么产生的,它的风味怎么样?
2021/03/04 新手入门
CI框架中libraries,helpers,hooks文件夹详细说明
2014/06/10 PHP
基于php实现的php代码加密解密类完整实例
2016/10/12 PHP
JS(jQuery)实现聊天接收到消息语言自动提醒功能详解【提示“您有新的消息请注意查收”】
2019/04/16 PHP
js传值 判断
2006/10/26 Javascript
jquery的Theme和Theme Switcher使用小结
2010/09/08 Javascript
jquery 延迟执行实例介绍
2013/08/20 Javascript
js 去除字符串第一位逗号的方法
2014/06/07 Javascript
浅谈javascript中replace()方法
2015/11/10 Javascript
Backbone.js框架中简单的View视图编写学习笔记
2016/02/14 Javascript
JavaScript组合模式学习要点
2016/08/26 Javascript
jquery pagination分页插件使用详解(后台struts2)
2017/01/22 Javascript
超简单的Vue.js环境搭建教程
2017/03/17 Javascript
ES6中Symbol类型用法实例详解
2017/04/06 Javascript
基于VUE移动音乐WEBAPP跨域请求失败的解决方法
2018/01/16 Javascript
对mac下nodejs 更新到最新版本的最新方法(推荐)
2018/05/17 NodeJs
JS插件clipboard.js实现一键复制粘贴功能
2020/12/04 Javascript
javascript定时器的简单应用示例【控制方块移动】
2019/06/17 Javascript
Vue-cli3.x + axios 跨域方案踩坑指北
2019/07/04 Javascript
Emberjs 通过 axios 下载文件的方法
2019/09/03 Javascript
[02:19]DOTA选手解说齐贺岁
2018/02/11 DOTA
python采用requests库模拟登录和抓取数据的简单示例
2014/07/05 Python
python实现RSA加密(解密)算法
2016/02/17 Python
python脚本监控docker容器
2016/04/27 Python
Django Admin中增加导出CSV功能过程解析
2019/09/04 Python
Python基于内置库pytesseract实现图片验证码识别功能
2020/02/24 Python
python selenium操作cookie的实现
2020/03/18 Python
推荐一些比较有用的css3新属性
2014/11/11 HTML / CSS
申请任职学生会干部自荐书范文
2014/02/13 职场文书
驾驶员培训方案
2014/05/01 职场文书
预防传染病方案
2014/06/14 职场文书
2014年司法局工作总结
2014/12/11 职场文书
幼师个人总结范文
2015/02/28 职场文书
加薪通知
2015/04/25 职场文书
Golang二维切片初始化的实现
2021/04/08 Golang
python套接字socket通信
2022/04/01 Python