keras使用Sequence类调用大规模数据集进行训练的实现


Posted in Python onJune 22, 2020

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码

class SequenceData(Sequence):
  def __init__(self, path, batch_size=32):
    self.path = path
    self.batch_size = batch_size
    f = open(path)
    self.datas = f.readlines()
    self.L = len(self.datas)
    self.index = random.sample(range(self.L), self.L)
  #返回长度,通过len(<你的实例>)调用
  def __len__(self):
    return self.L - self.batch_size
  #即通过索引获取a[0],a[1]这种
  def __getitem__(self, idx):
    batch_indexs = self.index[idx:(idx+self.batch_size)]
    batch_datas = [self.datas[k] for k in batch_indexs]
    img1s,img2s,audios,labels = self.data_generation(batch_datas)
    return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

  def data_generation(self, batch_datas):
    #预处理操作
    return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), 
          epochs=2, workers=20, #callbacks=[checkpoint],
          use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, datas, batch_size=1, shuffle=True):
    self.batch_size = batch_size
    self.datas = datas
    self.indexes = np.arange(len(self.datas))
    self.shuffle = shuffle

  def __len__(self):
    #计算每一个epoch的迭代次数
    return math.ceil(len(self.datas) / float(self.batch_size))

  def __getitem__(self, index):
    #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
    # 生成batch_size个索引
    batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    # 根据索引获取datas集合中的数据
    batch_datas = [self.datas[k] for k in batch_indexs]

    # 生成数据
    X, y = self.data_generation(batch_datas)

    return X, y

  def on_epoch_end(self):
    #在每一次epoch结束是否需要进行一次随机,重新随机一下index
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def data_generation(self, batch_datas):
    images = []
    labels = []

    # 生成数据
    for i, data in enumerate(batch_datas):
      #x_train数据
      image = cv2.imread(data)
      image = list(image)
      images.append(image)
      #y_train数据 
      right = data.rfind("\\",0)
      left = data.rfind("\\",0,right)+1
      class_name = data[left:right]
      if class_name=="dog":
        labels.append([0,1])
      else: 
        labels.append([1,0])
    #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
    return np.array(images), np.array(labels)
  
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
  file_path = os.path.join("D:/xxx", file)
  if os.path.isdir(file_path):
    class_num = class_num + 1
    for sub_file in os.listdir(file_path):
      train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用setup.py安装python包和卸载python包的方法
Nov 27 Python
python中将字典转换成其json字符串
Jul 16 Python
python调用Moxa PCOMM Lite通过串口Ymodem协议实现发送文件
Aug 15 Python
Python基于有道实现英汉字典功能
Jul 25 Python
Python 结巴分词实现关键词抽取分析
Oct 21 Python
Python函数返回不定数量的值方法
Jan 22 Python
一文秒懂python读写csv xml json文件各种骚操作
Jul 04 Python
Django在admin后台集成TinyMCE富文本编辑器的例子
Aug 09 Python
python标识符命名规范原理解析
Jan 10 Python
Python通过2种方法输出带颜色字体
Mar 02 Python
TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法
Apr 19 Python
Python xlwings插入Excel图片的实现方法
Feb 26 Python
Python socket服务常用操作代码实例
Jun 22 #Python
Python如何实现后端自定义认证并实现多条件登陆
Jun 22 #Python
零基础小白多久能学会python
Jun 22 #Python
Keras-多输入多输出实例(多任务)
Jun 22 #Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
You might like
又一个php 分页类实现代码
2009/12/03 PHP
使用php shell命令合并图片的代码
2011/06/23 PHP
PHP+jQuery实现自动补全功能源码
2013/05/15 PHP
PHP 5.5 创建和验证哈希最简单的方法详解
2013/11/07 PHP
PHP获取文件相对路径的方法
2015/02/26 PHP
适用于初学者的简易PHP文件上传类
2015/10/29 PHP
php实现通过stomp协议连接ActiveMQ操作示例
2020/02/23 PHP
符合W3C网页标准的iframe标签的使用方法
2007/07/19 Javascript
js 动态选中下拉框
2009/11/26 Javascript
javascript中的prototype属性实例分析说明
2010/08/09 Javascript
juqery 学习之三 选择器 简单 内容
2010/11/25 Javascript
js实时获取系统当前时间实例代码
2013/06/28 Javascript
获得Javascript对象属性个数的示例代码
2013/11/21 Javascript
jquery代码实现多选、不同分享功能
2015/07/31 Javascript
jQuery插件Timelinr 实现时间轴特效
2015/10/04 Javascript
实例解析jQuery插件EasyUI最常用的表单验证规则
2015/11/29 Javascript
jqueryMobile 动态添加元素,展示刷新视图的实现方法
2016/05/28 Javascript
jQuery实现下拉菜单(内容为时间)的实时更新及图表的随动更新的方法
2016/07/07 Javascript
js中apply与call简单用法详解
2017/11/06 Javascript
WebSocket的通信过程与实现方法详解
2018/04/29 Javascript
详解Node.js中path模块的resolve()和join()方法的区别
2018/10/29 Javascript
JavaScript中关于base64的一些事
2019/05/06 Javascript
PHP webshell检查工具 python实现代码
2009/09/15 Python
Python3.4学习笔记之类型判断,异常处理,终止程序操作小结
2019/03/01 Python
django使用haystack调用Elasticsearch实现索引搜索
2019/07/24 Python
Python之关于类变量的两种赋值区别详解
2020/03/12 Python
python 30行代码实现蚂蚁森林自动偷能量
2021/02/08 Python
CSS3图片旋转特效(360/60/-360度)
2013/10/10 HTML / CSS
详解CSS3选择器的使用方法汇总
2015/11/24 HTML / CSS
应届生自我鉴定
2013/12/11 职场文书
家长会邀请书
2014/01/25 职场文书
实习协议书范本
2014/04/22 职场文书
电子专业求职信
2014/06/19 职场文书
音乐教育专业自荐信
2014/09/18 职场文书
刑事上诉状范文
2015/05/22 职场文书
使用Redis实现分布式锁的方法
2022/06/16 Redis