keras使用Sequence类调用大规模数据集进行训练的实现


Posted in Python onJune 22, 2020

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码

class SequenceData(Sequence):
  def __init__(self, path, batch_size=32):
    self.path = path
    self.batch_size = batch_size
    f = open(path)
    self.datas = f.readlines()
    self.L = len(self.datas)
    self.index = random.sample(range(self.L), self.L)
  #返回长度,通过len(<你的实例>)调用
  def __len__(self):
    return self.L - self.batch_size
  #即通过索引获取a[0],a[1]这种
  def __getitem__(self, idx):
    batch_indexs = self.index[idx:(idx+self.batch_size)]
    batch_datas = [self.datas[k] for k in batch_indexs]
    img1s,img2s,audios,labels = self.data_generation(batch_datas)
    return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

  def data_generation(self, batch_datas):
    #预处理操作
    return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), 
          epochs=2, workers=20, #callbacks=[checkpoint],
          use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, datas, batch_size=1, shuffle=True):
    self.batch_size = batch_size
    self.datas = datas
    self.indexes = np.arange(len(self.datas))
    self.shuffle = shuffle

  def __len__(self):
    #计算每一个epoch的迭代次数
    return math.ceil(len(self.datas) / float(self.batch_size))

  def __getitem__(self, index):
    #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
    # 生成batch_size个索引
    batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    # 根据索引获取datas集合中的数据
    batch_datas = [self.datas[k] for k in batch_indexs]

    # 生成数据
    X, y = self.data_generation(batch_datas)

    return X, y

  def on_epoch_end(self):
    #在每一次epoch结束是否需要进行一次随机,重新随机一下index
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def data_generation(self, batch_datas):
    images = []
    labels = []

    # 生成数据
    for i, data in enumerate(batch_datas):
      #x_train数据
      image = cv2.imread(data)
      image = list(image)
      images.append(image)
      #y_train数据 
      right = data.rfind("\\",0)
      left = data.rfind("\\",0,right)+1
      class_name = data[left:right]
      if class_name=="dog":
        labels.append([0,1])
      else: 
        labels.append([1,0])
    #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
    return np.array(images), np.array(labels)
  
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
  file_path = os.path.join("D:/xxx", file)
  if os.path.isdir(file_path):
    class_num = class_num + 1
    for sub_file in os.listdir(file_path):
      train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
wxPython事件驱动实例详解
Sep 28 Python
Python标准库之随机数 (math包、random包)介绍
Nov 25 Python
利用Python破解验证码实例详解
Dec 08 Python
Python基础教程之浅拷贝和深拷贝实例详解
Jul 15 Python
Python编程pygal绘图实例之XY线
Dec 09 Python
Django自定义用户认证示例详解
Mar 14 Python
Pyqt QImage 与 np array 转换方法
Jun 27 Python
Python with语句和过程抽取思想
Dec 23 Python
python实现简单井字棋小游戏
Mar 05 Python
Python try except异常捕获机制原理解析
Apr 18 Python
python基于selenium爬取斗鱼弹幕
Feb 20 Python
python字典进行运算原理及实例分享
Aug 02 Python
Python socket服务常用操作代码实例
Jun 22 #Python
Python如何实现后端自定义认证并实现多条件登陆
Jun 22 #Python
零基础小白多久能学会python
Jun 22 #Python
Keras-多输入多输出实例(多任务)
Jun 22 #Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
You might like
mysql 全文搜索 技巧
2007/04/27 PHP
轻松修复Discuz!数据库
2008/05/03 PHP
PHP+MySQL 手工注入语句大全 推荐
2009/10/30 PHP
php使用function_exists判断函数可用的方法
2014/11/19 PHP
搭建Vim为自定义的PHP开发工具的一些技巧
2015/12/11 PHP
JavaScript高级程序设计(第3版)学习笔记7 js函数(上)
2012/10/11 Javascript
禁止拷贝网页内容的js代码
2014/01/22 Javascript
一个JavaScript防止表单重复提交的实例
2014/10/21 Javascript
js带点自动图片轮播幻灯片特效代码分享
2015/09/07 Javascript
jQuery获取父元素节点、子元素节点及兄弟元素节点的方法
2016/04/14 Javascript
详细分析Javascript中创建对象的四种方式
2016/08/17 Javascript
简单模拟node.js中require的加载机制
2016/10/27 Javascript
Angular.JS判断复选框checkbox是否选中并实时显示
2016/11/30 Javascript
在JS中a标签加入单击事件屏蔽href跳转页面
2016/12/16 Javascript
jQuery源码分析之sizzle选择器详解
2017/02/13 Javascript
vue实现简单loading进度条
2018/06/06 Javascript
对angularJs中$sce服务安全显示html文本的实例
2018/09/30 Javascript
Vue实现表格批量审核功能实例代码
2019/05/28 Javascript
js实现简单放大镜效果
2020/03/07 Javascript
[03:22]DOTA2超级联赛专访单车:找到属于自己的英雄
2013/06/08 DOTA
浅谈python中截取字符函数strip,lstrip,rstrip
2015/07/17 Python
Python yield与实现方法代码分析
2018/02/06 Python
Python实现的网页截图功能【PyQt4与selenium组件】
2018/07/12 Python
jenkins配置python脚本定时任务过程图解
2019/10/29 Python
使用python+whoosh实现全文检索
2019/12/09 Python
Anconda环境下Vscode安装Python的方法详解
2020/03/29 Python
香港化妆品经销商:我的公主
2016/08/05 全球购物
Elizabeth Gage官网:英国最好的珠宝设计之一
2020/09/26 全球购物
中英文自我评价语句
2013/12/20 职场文书
2014年三八妇女节活动方案
2014/02/28 职场文书
文明和谐家庭事迹材料
2014/05/18 职场文书
交通安全教育心得体会
2016/01/15 职场文书
技术转让协议书
2016/03/19 职场文书
浅谈@Value和@Bean的执行顺序问题
2021/06/16 Java/Android
未发现nvidia显卡怎么办?Win11系统中未检测到nvidia显卡解决教程
2022/04/08 数码科技
高通2023 年将发布高性能PC处理器
2022/04/29 数码科技