keras使用Sequence类调用大规模数据集进行训练的实现


Posted in Python onJune 22, 2020

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码

class SequenceData(Sequence):
  def __init__(self, path, batch_size=32):
    self.path = path
    self.batch_size = batch_size
    f = open(path)
    self.datas = f.readlines()
    self.L = len(self.datas)
    self.index = random.sample(range(self.L), self.L)
  #返回长度,通过len(<你的实例>)调用
  def __len__(self):
    return self.L - self.batch_size
  #即通过索引获取a[0],a[1]这种
  def __getitem__(self, idx):
    batch_indexs = self.index[idx:(idx+self.batch_size)]
    batch_datas = [self.datas[k] for k in batch_indexs]
    img1s,img2s,audios,labels = self.data_generation(batch_datas)
    return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

  def data_generation(self, batch_datas):
    #预处理操作
    return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), 
          epochs=2, workers=20, #callbacks=[checkpoint],
          use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, datas, batch_size=1, shuffle=True):
    self.batch_size = batch_size
    self.datas = datas
    self.indexes = np.arange(len(self.datas))
    self.shuffle = shuffle

  def __len__(self):
    #计算每一个epoch的迭代次数
    return math.ceil(len(self.datas) / float(self.batch_size))

  def __getitem__(self, index):
    #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
    # 生成batch_size个索引
    batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    # 根据索引获取datas集合中的数据
    batch_datas = [self.datas[k] for k in batch_indexs]

    # 生成数据
    X, y = self.data_generation(batch_datas)

    return X, y

  def on_epoch_end(self):
    #在每一次epoch结束是否需要进行一次随机,重新随机一下index
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def data_generation(self, batch_datas):
    images = []
    labels = []

    # 生成数据
    for i, data in enumerate(batch_datas):
      #x_train数据
      image = cv2.imread(data)
      image = list(image)
      images.append(image)
      #y_train数据 
      right = data.rfind("\\",0)
      left = data.rfind("\\",0,right)+1
      class_name = data[left:right]
      if class_name=="dog":
        labels.append([0,1])
      else: 
        labels.append([1,0])
    #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
    return np.array(images), np.array(labels)
  
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
  file_path = os.path.join("D:/xxx", file)
  if os.path.isdir(file_path):
    class_num = class_num + 1
    for sub_file in os.listdir(file_path):
      train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
教你如何在Django 1.6中正确使用 Signal
Jun 22 Python
python实现线程池的方法
Jun 30 Python
Python实现快速多线程ping的方法
Jul 15 Python
Python中工作日类库Busines Holiday的介绍与使用
Jul 06 Python
学习Python3 Dlib19.7进行人脸面部识别
Jan 24 Python
Python基于xlrd模块操作Excel的方法示例
Jun 21 Python
pygame游戏之旅 载入小车图片、更新窗口
Nov 20 Python
对Python3中bytes和HexStr之间的转换详解
Dec 04 Python
python 实现将txt文件多行合并为一行并将中间的空格去掉方法
Dec 20 Python
利用python list完成最简单的DB连接池方法
Aug 09 Python
自适应线性神经网络Adaline的python实现详解
Sep 30 Python
Python函数参数定义及传递方式解析
Jun 10 Python
Python socket服务常用操作代码实例
Jun 22 #Python
Python如何实现后端自定义认证并实现多条件登陆
Jun 22 #Python
零基础小白多久能学会python
Jun 22 #Python
Keras-多输入多输出实例(多任务)
Jun 22 #Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
You might like
php并发加锁示例
2016/10/17 PHP
详谈配置phpstorm完美支持Codeigniter(CI)代码自动完成(代码提示)
2017/04/07 PHP
laravel框架之数据库查出来的对象实现转化为数组
2019/10/23 PHP
Javascript 闭包引起的IE内存泄露分析
2012/05/23 Javascript
用Jquery重写windows.alert方法实现思路
2013/04/03 Javascript
JavaScript设置IFrame高度自适应(兼容各主流浏览器)
2013/06/05 Javascript
Jquery树插件zTree用法入门教程
2015/02/17 Javascript
jquery validate.js表单验证入门实例(附源码)
2015/11/10 Javascript
基于jQuery实现仿QQ空间送礼物功能代码
2016/05/24 Javascript
浅谈jquery.form.js的ajaxSubmit和ajaxForm的使用
2016/09/09 Javascript
three.js快速入门【推荐】
2017/01/21 Javascript
javascript验证香港身份证的格式或真实性
2017/02/07 Javascript
Angularjs按需查询实例代码
2017/10/30 Javascript
nodejs 如何手动实现服务器
2018/08/20 NodeJs
微信小程序系列之自定义顶部导航功能
2019/05/21 Javascript
微信小程序实现吸顶效果
2020/01/08 Javascript
vue页面引入three.js实现3d动画场景操作
2020/08/10 Javascript
跟老齐学Python之for循环语句
2014/10/02 Python
Python的Django框架中的Context使用
2015/07/15 Python
Python实现string字符串连接的方法总结【8种方式】
2018/07/06 Python
Pandas过滤dataframe中包含特定字符串的数据方法
2018/11/07 Python
python单线程下实现多个socket并发过程详解
2019/07/27 Python
Python numpy多维数组实现原理详解
2020/03/10 Python
python list等分并从等分的子集中随机选取一个数
2020/11/16 Python
基于python的opencv图像处理实现对斑马线的检测示例
2020/11/29 Python
曼联官方网上商店:Manchester United Direct
2017/07/28 全球购物
高中毕业生自我鉴定
2013/11/03 职场文书
蓝颜请假条
2014/04/11 职场文书
村容村貌整治方案
2014/05/21 职场文书
平安家庭示范户事迹
2014/06/02 职场文书
论群众路线学习心得体会
2014/10/31 职场文书
大三学生英语考试作弊检讨书
2015/01/01 职场文书
被告代理词范文
2015/05/25 职场文书
大学生干部培训心得体会
2016/01/06 职场文书
详解MySQL的Seconds_Behind_Master
2021/05/18 MySQL
Java设计模式之享元模式示例详解
2022/03/03 Java/Android