keras使用Sequence类调用大规模数据集进行训练的实现


Posted in Python onJune 22, 2020

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码

class SequenceData(Sequence):
  def __init__(self, path, batch_size=32):
    self.path = path
    self.batch_size = batch_size
    f = open(path)
    self.datas = f.readlines()
    self.L = len(self.datas)
    self.index = random.sample(range(self.L), self.L)
  #返回长度,通过len(<你的实例>)调用
  def __len__(self):
    return self.L - self.batch_size
  #即通过索引获取a[0],a[1]这种
  def __getitem__(self, idx):
    batch_indexs = self.index[idx:(idx+self.batch_size)]
    batch_datas = [self.datas[k] for k in batch_indexs]
    img1s,img2s,audios,labels = self.data_generation(batch_datas)
    return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

  def data_generation(self, batch_datas):
    #预处理操作
    return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), 
          epochs=2, workers=20, #callbacks=[checkpoint],
          use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, datas, batch_size=1, shuffle=True):
    self.batch_size = batch_size
    self.datas = datas
    self.indexes = np.arange(len(self.datas))
    self.shuffle = shuffle

  def __len__(self):
    #计算每一个epoch的迭代次数
    return math.ceil(len(self.datas) / float(self.batch_size))

  def __getitem__(self, index):
    #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
    # 生成batch_size个索引
    batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    # 根据索引获取datas集合中的数据
    batch_datas = [self.datas[k] for k in batch_indexs]

    # 生成数据
    X, y = self.data_generation(batch_datas)

    return X, y

  def on_epoch_end(self):
    #在每一次epoch结束是否需要进行一次随机,重新随机一下index
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def data_generation(self, batch_datas):
    images = []
    labels = []

    # 生成数据
    for i, data in enumerate(batch_datas):
      #x_train数据
      image = cv2.imread(data)
      image = list(image)
      images.append(image)
      #y_train数据 
      right = data.rfind("\\",0)
      left = data.rfind("\\",0,right)+1
      class_name = data[left:right]
      if class_name=="dog":
        labels.append([0,1])
      else: 
        labels.append([1,0])
    #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
    return np.array(images), np.array(labels)
  
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
  file_path = os.path.join("D:/xxx", file)
  if os.path.isdir(file_path):
    class_num = class_num + 1
    for sub_file in os.listdir(file_path):
      train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python求斐波那契数列示例分享
Feb 14 Python
详解Python迭代和迭代器
Mar 28 Python
解决python3在anaconda下安装caffe失败的问题
Jun 15 Python
理解Python中的绝对路径和相对路径
Aug 30 Python
Python和Java进行DES加密和解密的实例
Jan 09 Python
Python3 SSH远程连接服务器的方法示例
Dec 29 Python
Python设计模式之策略模式实例详解
Jan 21 Python
解决python xx.py文件点击完之后一闪而过的问题
Jun 24 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 Python
Python+unittest+DDT实现数据驱动测试
Nov 30 Python
Python使用Opencv打开笔记本电脑摄像头报错解问题及解决
Jun 21 Python
Python socket服务常用操作代码实例
Jun 22 #Python
Python如何实现后端自定义认证并实现多条件登陆
Jun 22 #Python
零基础小白多久能学会python
Jun 22 #Python
Keras-多输入多输出实例(多任务)
Jun 22 #Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
You might like
SONY SRF-40W电路分析
2021/03/02 无线电
php基于自定义函数记录log日志方法
2017/07/21 PHP
javascript 三种编解码方式
2010/02/01 Javascript
JavaScript CSS修改学习第一章 查找位置
2010/02/19 Javascript
Extjs入门之动态加载树代码
2010/04/09 Javascript
点评js异步加载的4种方式
2015/12/22 Javascript
jQuery Ajax 全局调用封装实例代码详解
2016/06/02 Javascript
js 判断一组日期是否是连续的简单实例
2016/07/11 Javascript
利用Vue.js指令实现全选功能
2016/09/08 Javascript
jQuery实现的放大镜效果示例
2016/09/13 Javascript
Vue.js动态添加、删除选题的实例代码
2016/09/30 Javascript
vue 实现通过手机发送短信验证码注册功能
2018/04/19 Javascript
微信小程序自定义弹窗wcPop插件
2018/11/19 Javascript
vue-drawer-layout实现手势滑出菜单栏
2020/11/19 Vue.js
[03:32]2014DOTA2西雅图邀请赛 CIS外卡赛赛前black专访
2014/07/09 DOTA
[01:06]DOTA2亚洲邀请赛专属珍藏-荧煌之礼
2017/03/24 DOTA
简单介绍Python中利用生成器实现的并发编程
2015/05/04 Python
约瑟夫问题的Python和C++求解方法
2015/08/20 Python
Python Paramiko模块的安装与使用详解
2016/11/18 Python
Python正则表达式非贪婪、多行匹配功能示例
2017/08/08 Python
python3实现163邮箱SMTP发送邮件
2018/05/22 Python
Python找出微信上删除你好友的人脚本写法
2018/11/01 Python
使用python serial 获取所有的串口名称的实例
2019/07/02 Python
Html5实现如何在两个div元素之间拖放图像
2013/03/29 HTML / CSS
HTML5中5个简单实用的API
2014/04/28 HTML / CSS
Doyoueven官网:澳大利亚健身服饰和配饰品牌
2019/03/24 全球购物
校园十大歌手策划书
2014/02/01 职场文书
入党思想汇报怎么写
2014/04/03 职场文书
讲解员培训方案
2014/05/04 职场文书
保洁员岗位职责
2015/02/04 职场文书
小学数学教师研修感悟
2015/11/18 职场文书
JavaScript canvas实现流星特效
2021/05/20 Javascript
Spring Boot 排除某个类加载注入IOC的操作
2021/08/02 Java/Android
JavaScript 原型与原型链详情
2021/11/02 Javascript
Python如何快速找到多个字典中的公共键(key)
2022/04/29 Python
Redis 异步机制
2022/05/15 Redis