keras使用Sequence类调用大规模数据集进行训练的实现


Posted in Python onJune 22, 2020

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码

class SequenceData(Sequence):
  def __init__(self, path, batch_size=32):
    self.path = path
    self.batch_size = batch_size
    f = open(path)
    self.datas = f.readlines()
    self.L = len(self.datas)
    self.index = random.sample(range(self.L), self.L)
  #返回长度,通过len(<你的实例>)调用
  def __len__(self):
    return self.L - self.batch_size
  #即通过索引获取a[0],a[1]这种
  def __getitem__(self, idx):
    batch_indexs = self.index[idx:(idx+self.batch_size)]
    batch_datas = [self.datas[k] for k in batch_indexs]
    img1s,img2s,audios,labels = self.data_generation(batch_datas)
    return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

  def data_generation(self, batch_datas):
    #预处理操作
    return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), 
          epochs=2, workers=20, #callbacks=[checkpoint],
          use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, datas, batch_size=1, shuffle=True):
    self.batch_size = batch_size
    self.datas = datas
    self.indexes = np.arange(len(self.datas))
    self.shuffle = shuffle

  def __len__(self):
    #计算每一个epoch的迭代次数
    return math.ceil(len(self.datas) / float(self.batch_size))

  def __getitem__(self, index):
    #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
    # 生成batch_size个索引
    batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    # 根据索引获取datas集合中的数据
    batch_datas = [self.datas[k] for k in batch_indexs]

    # 生成数据
    X, y = self.data_generation(batch_datas)

    return X, y

  def on_epoch_end(self):
    #在每一次epoch结束是否需要进行一次随机,重新随机一下index
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def data_generation(self, batch_datas):
    images = []
    labels = []

    # 生成数据
    for i, data in enumerate(batch_datas):
      #x_train数据
      image = cv2.imread(data)
      image = list(image)
      images.append(image)
      #y_train数据 
      right = data.rfind("\\",0)
      left = data.rfind("\\",0,right)+1
      class_name = data[left:right]
      if class_name=="dog":
        labels.append([0,1])
      else: 
        labels.append([1,0])
    #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
    return np.array(images), np.array(labels)
  
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
  file_path = os.path.join("D:/xxx", file)
  if os.path.isdir(file_path):
    class_num = class_num + 1
    for sub_file in os.listdir(file_path):
      train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python实现strcmp函数功能示例
Mar 25 Python
Python自动化测试工具Splinter简介和使用实例
May 13 Python
Python常见数据结构详解
Jul 24 Python
Python使用matplotlib简单绘图示例
Feb 01 Python
深入浅析Python 中 is 语法带来的误解
May 07 Python
使用pyinstaller打包PyQt4程序遇到的问题及解决方法
Jun 24 Python
python批量修改图片尺寸,并保存指定路径的实现方法
Jul 04 Python
Python数据可视化:饼状图的实例讲解
Dec 07 Python
Django的CVB实例详解
Feb 10 Python
TensorBoard 计算图的查看方式
Feb 15 Python
Python实现EM算法实例代码
Oct 04 Python
如何使用python包中的sched事件调度器
Apr 30 Python
Python socket服务常用操作代码实例
Jun 22 #Python
Python如何实现后端自定义认证并实现多条件登陆
Jun 22 #Python
零基础小白多久能学会python
Jun 22 #Python
Keras-多输入多输出实例(多任务)
Jun 22 #Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
You might like
PHP中改变图片的尺寸大小的代码
2011/07/17 PHP
PHP面向对象程序设计模拟一般面向对象语言中的方法重载(overload)示例
2019/06/13 PHP
jquery 学习之二 属性相关
2010/11/23 Javascript
Javascript中自动切换焦点实现代码
2012/12/15 Javascript
Nodejs实现的一个简单udp广播服务器、客户端
2014/09/25 NodeJs
JsRender实用入门教程
2014/10/31 Javascript
javascript实现checkbox复选框实例代码
2016/01/10 Javascript
JavaScript 弹出子窗体并返回结果到父窗体的实现代码
2016/05/28 Javascript
jQuery、zepto、js常用小技巧
2017/02/12 Javascript
80%应聘者都不及格的JS面试题
2017/03/21 Javascript
nginx配置React静态页面的方法教程
2017/11/03 Javascript
js实现数组内数据的上移和下移的实例
2017/11/14 Javascript
angularJs中json数据转换与本地存储的实例
2018/10/08 Javascript
Javascript迭代、递推、穷举、递归常用算法实例讲解
2019/02/01 Javascript
ES6 如何改变JS内置行为的代理与反射
2019/02/11 Javascript
VUE安装使用教程详解
2019/06/03 Javascript
vue 集成jTopo 处理方法
2019/08/07 Javascript
vue实现简单瀑布流布局
2020/05/28 Javascript
vue2.0实现列表数据增加和删除
2020/06/17 Javascript
[06:43]2018DOTA2国际邀请赛寻真——VGJ.Thunder
2018/08/11 DOTA
学习python (2)
2006/10/31 Python
分析Python的Django框架的运行方式及处理流程
2015/04/08 Python
python fabric使用笔记
2015/05/09 Python
调用其他python脚本文件里面的类和方法过程解析
2019/11/15 Python
python使用HTMLTestRunner导出饼图分析报告的方法
2019/12/30 Python
深入了解如何基于Python读写Kafka
2019/12/31 Python
基于Python第三方插件实现西游记章节标注汉语拼音的方法
2020/05/22 Python
通过Django Admin+HttpRunner1.5.6实现简易接口测试平台
2020/11/11 Python
Evisu官方网站:日本牛仔品牌,时尚街头设计风格
2016/12/30 全球购物
美国豪华的多品牌精品店:The Webster
2019/07/31 全球购物
static全局变量与普通的全局变量有什么区别?static局部变量和普通局部变量有什么区别?static函数与普通函数有什么区别?
2015/02/22 面试题
2014领导班子专题民主生活会对照检查材料思想汇报
2014/09/23 职场文书
2014年结对帮扶工作总结
2014/12/17 职场文书
2014年为民办实事工作总结
2014/12/20 职场文书
银行员工考核评语
2014/12/31 职场文书
Nginx 过滤静态资源文件的访问日志的实现
2021/03/31 Servers