浅谈keras通过model.fit_generator训练模型(节省内存)


Posted in Python onJune 17, 2020

前言

前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的。

如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,可以分批次的读取数据,节省了我们的内存,我们唯一要做的就是实现一个生成器(generator)。

1.fit_generator函数简介

fit_generator(generator, 
steps_per_epoch=None, 
epochs=1, 
verbose=1, 
callbacks=None, 
validation_data=None, 
validation_steps=None, 
class_weight=None, 
max_queue_size=10, 
workers=1, 
use_multiprocessing=False, 
shuffle=True, 
initial_epoch=0)

参数:

generator:一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例。这是我们实现的重点,后面会着介绍生成器和sequence的两种实现方式。

steps_per_epoch:这个是我们在每个epoch中需要执行多少次生成器来生产数据,fit_generator函数没有batch_size这个参数,是通过steps_per_epoch来实现的,每次生产的数据就是一个batch,因此steps_per_epoch的值我们通过会设为(样本数/batch_size)。如果我们的generator是sequence类型,那么这个参数是可选的,默认使用len(generator) 。

epochs:即我们训练的迭代次数。

verbose:0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行

callbacks:在训练时调用的一系列回调函数。

validation_data:和我们的generator类似,只是这个使用于验证的,不参与训练。

validation_steps:和前面的steps_per_epoch类似。

class_weight:可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。(感觉这个参数用的比较少)

max_queue_size:整数。生成器队列的最大尺寸。默认为10.

workers:整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。

use_multiprocessing:布尔值。如果 True,则使用基于进程的多线程。默认为False。

shuffle:是否在每轮迭代之前打乱 batch 的顺序。 只能与Sequence(keras.utils.Sequence) 实例同用。

initial_epoch: 开始训练的轮次(有助于恢复之前的训练)

2.generator实现

2.1生成器的实现方式

样例代码:

import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image

def process_x(path):
 img = Image.open(path)
 img = img.resize((96,96))
 img = img.convert('RGB')
 img = np.array(img)

 img = np.asarray(img, np.float32) / 255.0
 #也可以进行进行一些数据数据增强的处理
 return img

count =1
def generate_arrays_from_file(x_y):
 #x_y 是我们的训练集包括标签,每一行的第一个是我们的图片路径,后面的是我们的独热化后的标签

 global count
 batch_size = 8
 while 1:
  batch_x = x_y[(count - 1) * batch_size:count * batch_size, 0]
  batch_y = x_y[(count - 1) * batch_size:count * batch_size, 1:]

  batch_x = np.array([process_x(img_path) for img_path in batch_x])
  batch_y = np.array(batch_y).astype(np.float32)
  print("count:"+str(count))
  count = count+1
  yield (batch_x, batch_y)

model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=2))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])

x_y = []
model.fit_generator(generate_arrays_from_file(x_y),steps_per_epoch=10, epochs=2,max_queue_size=1,workers=1)

在理解上面代码之前我们需要首先了解yield的用法。

yield关键字:

我们先通过一个例子看一下yield的用法:

def foo():
 print("starting...")
 while True:
  res = yield 4
  print("res:",res)
g = foo()
print(next(g))
print("----------")
print(next(g))

运行结果:

starting...
4
----------
res: None
4

带yield的函数是一个生成器,而不是一个函数。因为foo函数中有yield关键字,所以foo函数并不会真的执行,而是先得到一个生成器的实例,当我们第一次调用next函数的时候,foo函数才开始行,首先先执行foo函数中的print方法,然后进入while循环,循环执行到yield时,yield其实相当于return,函数返回4,程序停止。所以我们第一次调用next(g)的输出结果是前面两行。

然后当我们再次调用next(g)时,这个时候是从上一次停止的地方继续执行,也就是要执行res的赋值操作,因为4已经在上一次执行被return了,随意赋值res为None,然后执行print(“res:”,res)打印res: None,再次循环到yield返回4,程序停止。

所以yield关键字的作用就是我们能够从上一次程序停止的地方继续执行,这样我们用作生成器的时候,就避免一次性读入数据造成内存不足的情况。

现在看到上面的示例代码:

generate_arrays_from_file函数就是我们的生成器,每次循环读取一个batch大小的数据,然后处理数据,并返回。x_y是我们的把路径和标签合并后的训练集,类似于如下形式:

['data/img\\fimg_4092.jpg' '0' '1' '0' '0' '0' ]

至于格式不一定要这样,可以是自己的格式,至于怎么处理,根于自己的格式,在process_x进行处理,这里因为是存放的图片路径,所以在process_x函数的主要作用就是读取图片并进行归一化等操作,也可以在这里定义自己需要进行的操作,例如对图像进行实时数据增强。

2.2使用Sequence实现generator

示例代码:

class BaseSequence(Sequence):
 """
 基础的数据流生成器,每次迭代返回一个batch
 BaseSequence可直接用于fit_generator的generator参数
 fit_generator会将BaseSequence再次封装为一个多进程的数据流生成器
 而且能保证在多进程下的一个epoch中不会重复取相同的样本
 """
 def __init__(self, img_paths, labels, batch_size, img_size):
  #np.hstack在水平方向上平铺
  self.x_y = np.hstack((np.array(img_paths).reshape(len(img_paths), 1), np.array(labels)))
  self.batch_size = batch_size
  self.img_size = img_size

 def __len__(self):
  #math.ceil表示向上取整
  #调用len(BaseSequence)时返回,返回的是每个epoch我们需要读取数据的次数
  return math.ceil(len(self.x_y) / self.batch_size)

 def preprocess_img(self, img_path):

  img = Image.open(img_path)
  resize_scale = self.img_size[0] / max(img.size[:2])
  img = img.resize((self.img_size[0], self.img_size[0]))
  img = img.convert('RGB')
  img = np.array(img)

  # 数据归一化
  img = np.asarray(img, np.float32) / 255.0
  return img

 def __getitem__(self, idx):
  batch_x = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 0]
  batch_y = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 1:]
  batch_x = np.array([self.preprocess_img(img_path) for img_path in batch_x])
  batch_y = np.array(batch_y).astype(np.float32)
  print(batch_x.shape)
  return batch_x, batch_y
 #重写的父类Sequence中的on_epoch_end方法,在每次迭代完后调用。
 def on_epoch_end(self):
  #每次迭代后重新打乱训练集数据
  np.random.shuffle(self.x_y)

在上面代码中,__len __和__getitem __,是我们重写的魔法方法,__len __是当我们调用len(BaseSequence)函数时调用,这里我们返回(样本总量/batch_size),供我们传入fit_generator中的steps_per_epoch参数;__getitem __可以让对象实现迭代功能,这样在将BaseSequence的对象传入fit_generator中后,不断执行generator就可循环的读取数据了。

举个例子说明一下getitem的作用:

class Animal:
 def __init__(self, animal_list):
  self.animals_name = animal_list

 def __getitem__(self, index):
  return self.animals_name[index]

animals = Animal(["dog","cat","fish"])
for animal in animals:
 print(animal)

输出结果:

dog
cat
fish

并且使用Sequence类可以保证在多进程的情况下,每个epoch中的样本只会被训练一次。

以上这篇浅谈keras通过model.fit_generator训练模型(节省内存)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python基础之函数用法实例详解
Sep 10 Python
Python中的闭包总结
Sep 18 Python
python递归查询菜单并转换成json实例
Mar 27 Python
pandas按若干个列的组合条件筛选数据的方法
Apr 11 Python
Python中数组,列表:冒号的灵活用法介绍(np数组,列表倒序)
Apr 18 Python
python+selenium打印当前页面的titl和url方法
Jun 22 Python
对python创建及引用动态变量名的示例讲解
Nov 10 Python
Python 实现子类获取父类的类成员方法
Jan 11 Python
Python设计模式之策略模式实例详解
Jan 21 Python
Python 隐藏输入密码时屏幕回显的实例
Feb 19 Python
浅析PEP572: 海象运算符
Oct 15 Python
Python configparser模块配置文件过程解析
Mar 03 Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
python实例化对象的具体方法
Jun 17 #Python
python和php学习哪个更有发展
Jun 17 #Python
python中线程和进程有何区别
Jun 17 #Python
You might like
php中如何使对象可以像数组一样进行foreach循环
2013/08/09 PHP
PHP的一个完美GIF等比缩放类,附带去除缩放黑背景
2014/04/01 PHP
php实现用于验证所有类型的信用卡类
2015/03/24 PHP
PHP实现根据数组的值进行分组的方法
2017/04/20 PHP
javascript 鼠标拖动图标技术
2010/02/07 Javascript
jQuery 表单验证扩展(三)
2010/10/20 Javascript
js控制href内容的连接内容的变化示例
2014/04/30 Javascript
href下载文件根据id取url并下载
2014/05/28 Javascript
JQuery记住用户名密码实现下次自动登录功能
2015/04/27 Javascript
jQuery中(function($){})(jQuery)详解
2015/07/15 Javascript
jQuery 3.0中存在问题及解决办法
2016/07/15 Javascript
微信小程序 获取设备信息 API实例详解
2016/10/02 Javascript
vue2 如何实现div contenteditable=“true”(类似于v-model)的效果
2017/02/08 Javascript
利用js编写网页进度条效果
2017/10/08 Javascript
react 实现页面代码分割、按需加载的方法
2018/04/03 Javascript
Vue Promise的axios请求封装详解
2018/08/13 Javascript
TensorFlow实现Softmax回归模型
2018/03/09 Python
Pandas之Dropna滤除缺失数据的实现方法
2019/06/25 Python
Python笔记之代理模式
2019/11/20 Python
python:动态路由的Flask程序代码
2019/11/22 Python
pytorch实现对输入超过三通道的数据进行训练
2020/01/15 Python
Python中 Global和Nonlocal的用法详解
2020/01/20 Python
tensorflow模型保存、加载之变量重命名实例
2020/01/21 Python
Python requests模块session代码实例
2020/04/14 Python
django模型类中,null=True,blank=True用法说明
2020/07/09 Python
PyCharm vs VSCode,作为python开发者,你更倾向哪种IDE呢?
2020/08/17 Python
CSS实现圆形放大镜狙击镜效果 只有圆圈里的放大
2012/12/10 HTML / CSS
Paradigit比利时电脑卖场:购买笔记本、电脑、平板和外围设备
2016/11/28 全球购物
数据库基础的一些面试题
2012/02/25 面试题
高中生学习总结的自我评价范文
2013/10/13 职场文书
计算机专业毕业生求职信分享
2013/12/24 职场文书
2015年消费者权益日活动总结
2015/02/09 职场文书
交心谈心活动总结
2015/05/11 职场文书
亲戚关系证明
2015/06/24 职场文书
大学生社区义工服务心得体会
2016/01/22 职场文书
用Python selenium实现淘宝抢单机器人
2021/06/18 Python