浅谈keras通过model.fit_generator训练模型(节省内存)


Posted in Python onJune 17, 2020

前言

前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的。

如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,可以分批次的读取数据,节省了我们的内存,我们唯一要做的就是实现一个生成器(generator)。

1.fit_generator函数简介

fit_generator(generator, 
steps_per_epoch=None, 
epochs=1, 
verbose=1, 
callbacks=None, 
validation_data=None, 
validation_steps=None, 
class_weight=None, 
max_queue_size=10, 
workers=1, 
use_multiprocessing=False, 
shuffle=True, 
initial_epoch=0)

参数:

generator:一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例。这是我们实现的重点,后面会着介绍生成器和sequence的两种实现方式。

steps_per_epoch:这个是我们在每个epoch中需要执行多少次生成器来生产数据,fit_generator函数没有batch_size这个参数,是通过steps_per_epoch来实现的,每次生产的数据就是一个batch,因此steps_per_epoch的值我们通过会设为(样本数/batch_size)。如果我们的generator是sequence类型,那么这个参数是可选的,默认使用len(generator) 。

epochs:即我们训练的迭代次数。

verbose:0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行

callbacks:在训练时调用的一系列回调函数。

validation_data:和我们的generator类似,只是这个使用于验证的,不参与训练。

validation_steps:和前面的steps_per_epoch类似。

class_weight:可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。(感觉这个参数用的比较少)

max_queue_size:整数。生成器队列的最大尺寸。默认为10.

workers:整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。

use_multiprocessing:布尔值。如果 True,则使用基于进程的多线程。默认为False。

shuffle:是否在每轮迭代之前打乱 batch 的顺序。 只能与Sequence(keras.utils.Sequence) 实例同用。

initial_epoch: 开始训练的轮次(有助于恢复之前的训练)

2.generator实现

2.1生成器的实现方式

样例代码:

import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image

def process_x(path):
 img = Image.open(path)
 img = img.resize((96,96))
 img = img.convert('RGB')
 img = np.array(img)

 img = np.asarray(img, np.float32) / 255.0
 #也可以进行进行一些数据数据增强的处理
 return img

count =1
def generate_arrays_from_file(x_y):
 #x_y 是我们的训练集包括标签,每一行的第一个是我们的图片路径,后面的是我们的独热化后的标签

 global count
 batch_size = 8
 while 1:
  batch_x = x_y[(count - 1) * batch_size:count * batch_size, 0]
  batch_y = x_y[(count - 1) * batch_size:count * batch_size, 1:]

  batch_x = np.array([process_x(img_path) for img_path in batch_x])
  batch_y = np.array(batch_y).astype(np.float32)
  print("count:"+str(count))
  count = count+1
  yield (batch_x, batch_y)

model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=2))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])

x_y = []
model.fit_generator(generate_arrays_from_file(x_y),steps_per_epoch=10, epochs=2,max_queue_size=1,workers=1)

在理解上面代码之前我们需要首先了解yield的用法。

yield关键字:

我们先通过一个例子看一下yield的用法:

def foo():
 print("starting...")
 while True:
  res = yield 4
  print("res:",res)
g = foo()
print(next(g))
print("----------")
print(next(g))

运行结果:

starting...
4
----------
res: None
4

带yield的函数是一个生成器,而不是一个函数。因为foo函数中有yield关键字,所以foo函数并不会真的执行,而是先得到一个生成器的实例,当我们第一次调用next函数的时候,foo函数才开始行,首先先执行foo函数中的print方法,然后进入while循环,循环执行到yield时,yield其实相当于return,函数返回4,程序停止。所以我们第一次调用next(g)的输出结果是前面两行。

然后当我们再次调用next(g)时,这个时候是从上一次停止的地方继续执行,也就是要执行res的赋值操作,因为4已经在上一次执行被return了,随意赋值res为None,然后执行print(“res:”,res)打印res: None,再次循环到yield返回4,程序停止。

所以yield关键字的作用就是我们能够从上一次程序停止的地方继续执行,这样我们用作生成器的时候,就避免一次性读入数据造成内存不足的情况。

现在看到上面的示例代码:

generate_arrays_from_file函数就是我们的生成器,每次循环读取一个batch大小的数据,然后处理数据,并返回。x_y是我们的把路径和标签合并后的训练集,类似于如下形式:

['data/img\\fimg_4092.jpg' '0' '1' '0' '0' '0' ]

至于格式不一定要这样,可以是自己的格式,至于怎么处理,根于自己的格式,在process_x进行处理,这里因为是存放的图片路径,所以在process_x函数的主要作用就是读取图片并进行归一化等操作,也可以在这里定义自己需要进行的操作,例如对图像进行实时数据增强。

2.2使用Sequence实现generator

示例代码:

class BaseSequence(Sequence):
 """
 基础的数据流生成器,每次迭代返回一个batch
 BaseSequence可直接用于fit_generator的generator参数
 fit_generator会将BaseSequence再次封装为一个多进程的数据流生成器
 而且能保证在多进程下的一个epoch中不会重复取相同的样本
 """
 def __init__(self, img_paths, labels, batch_size, img_size):
  #np.hstack在水平方向上平铺
  self.x_y = np.hstack((np.array(img_paths).reshape(len(img_paths), 1), np.array(labels)))
  self.batch_size = batch_size
  self.img_size = img_size

 def __len__(self):
  #math.ceil表示向上取整
  #调用len(BaseSequence)时返回,返回的是每个epoch我们需要读取数据的次数
  return math.ceil(len(self.x_y) / self.batch_size)

 def preprocess_img(self, img_path):

  img = Image.open(img_path)
  resize_scale = self.img_size[0] / max(img.size[:2])
  img = img.resize((self.img_size[0], self.img_size[0]))
  img = img.convert('RGB')
  img = np.array(img)

  # 数据归一化
  img = np.asarray(img, np.float32) / 255.0
  return img

 def __getitem__(self, idx):
  batch_x = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 0]
  batch_y = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 1:]
  batch_x = np.array([self.preprocess_img(img_path) for img_path in batch_x])
  batch_y = np.array(batch_y).astype(np.float32)
  print(batch_x.shape)
  return batch_x, batch_y
 #重写的父类Sequence中的on_epoch_end方法,在每次迭代完后调用。
 def on_epoch_end(self):
  #每次迭代后重新打乱训练集数据
  np.random.shuffle(self.x_y)

在上面代码中,__len __和__getitem __,是我们重写的魔法方法,__len __是当我们调用len(BaseSequence)函数时调用,这里我们返回(样本总量/batch_size),供我们传入fit_generator中的steps_per_epoch参数;__getitem __可以让对象实现迭代功能,这样在将BaseSequence的对象传入fit_generator中后,不断执行generator就可循环的读取数据了。

举个例子说明一下getitem的作用:

class Animal:
 def __init__(self, animal_list):
  self.animals_name = animal_list

 def __getitem__(self, index):
  return self.animals_name[index]

animals = Animal(["dog","cat","fish"])
for animal in animals:
 print(animal)

输出结果:

dog
cat
fish

并且使用Sequence类可以保证在多进程的情况下,每个epoch中的样本只会被训练一次。

以上这篇浅谈keras通过model.fit_generator训练模型(节省内存)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的内存泄漏及gc模块的使用分析
Jul 16 Python
Python中的自定义函数学习笔记
Sep 23 Python
基于python脚本实现软件的注册功能(机器码+注册码机制)
Oct 09 Python
python常见排序算法基础教程
Apr 13 Python
有关Python的22个编程技巧
Aug 29 Python
pygame游戏之旅 添加游戏暂停功能
Nov 21 Python
python 内置模块详解
Jan 01 Python
Python中的pathlib.Path为什么不继承str详解
Jun 23 Python
python处理自动化任务之同时批量修改word里面的内容的方法
Aug 23 Python
利用Python小工具实现3秒钟将视频转换为音频
Oct 29 Python
selenium+Chrome滑动验证码破解二(某某网站)
Dec 17 Python
Python扫描端口的实现
Jan 25 Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
python实例化对象的具体方法
Jun 17 #Python
python和php学习哪个更有发展
Jun 17 #Python
python中线程和进程有何区别
Jun 17 #Python
You might like
使用MaxMind 根据IP地址对访问者定位
2006/10/09 PHP
一个SQL管理员的web接口
2006/10/09 PHP
使用php判断浏览器的类型和语言的函数代码
2013/02/28 PHP
linux系统下php安装mbstring扩展的二种方法
2014/01/20 PHP
PHP实现获取文件后缀名的几种常用方法
2015/08/08 PHP
PHP使用http_build_query()构造URL字符串的方法
2016/04/02 PHP
使用PHPExcel导出Excel表
2018/09/08 PHP
无语,javascript居然支持中文(unicode)编程!
2007/04/12 Javascript
JavaScript 检测浏览器和操作系统的脚本
2008/12/26 Javascript
javascript中对Attr(dom中属性)的操作示例讲解
2013/12/02 Javascript
解析Javascript中大括号“{}”的多义性
2013/12/02 Javascript
jQuery实现当按下回车键时绑定点击事件
2014/01/28 Javascript
jquery实现简易的移动端验证表单
2015/11/08 Javascript
Bootstrap 源代码分析(未完待续)
2016/08/17 Javascript
JS刷新父窗口的几种方式小结(推荐)
2016/11/09 Javascript
Javascript中将变量转换为字符串的三种方法
2017/09/19 Javascript
Vue多系统切换实现方案
2018/06/05 Javascript
vue-cli3.0+element-ui上传组件el-upload的使用
2018/12/03 Javascript
uniapp实现横向滚动选择日期
2020/10/21 Javascript
python sort、sorted高级排序技巧
2014/11/21 Python
详解Swift中属性的声明与作用
2016/06/30 Python
python将unicode转为str的方法
2017/06/21 Python
Python如何生成树形图案
2018/01/03 Python
Python通过for循环理解迭代器和生成器实例详解
2019/02/16 Python
python实现视频分帧效果
2019/05/31 Python
Pytorch中的VGG实现修改最后一层FC
2020/01/15 Python
使用pth文件添加Python环境变量方式
2020/05/26 Python
Python3中对json格式数据的分析处理
2021/01/28 Python
Django与AJAX实现网页动态数据显示的示例代码
2021/02/24 Python
使用canvas实现黑客帝国数字雨效果
2020/01/02 HTML / CSS
S’well Bottle保温杯官网:绝缘不锈钢水瓶
2018/05/09 全球购物
Herve Leger官网:标志性绷带连衣裙等
2018/12/26 全球购物
工作评语大全
2014/04/26 职场文书
2014年控辍保学工作总结
2014/12/08 职场文书
党的群众路线教育实践活动先进个人材料
2014/12/24 职场文书
幼儿园教师节感谢信
2015/01/23 职场文书