Keras之fit_generator与train_on_batch用法


Posted in Python onJune 17, 2020

关于Keras中,当数据比较大时,不能全部载入内存,在训练的时候就需要利用train_on_batch或fit_generator进行训练了。

两者均是利用生成器,每次载入一个batch-size的数据进行训练。

那么fit_generator与train_on_batch该用哪一个呢?

train_on_batch(self, x, y, class_weight=None, sample_weight=None)

fit_generator(self, generator, samples_per_epoch, nb_epoch, verbose=1, callbacks=[], validation_data=None, nb_val_samples=None, class_weight=None, max_q_size=10)

推荐使用fit_generator,因为其同时可以设置 validation_data,但是采用train_on_batch也没什么问题,这个主要看个人习惯了,没有什么标准的答案。

下面是François Chollet fchollet本人给出的解答:

With fit_generator, you can use a generator for the validation data as well. In general I would recommend using fit_generator, but using train_on_batch works fine too. These methods only exist as for the sake of convenience in different use cases, there is no "correct" method.

补充知识:tf.keras中model.fit_generator()和model.fit()

首先Keras中的fit()函数传入的x_train和y_train是被完整的加载进内存的,当然用起来很方便,但是如果我们数据量很大,那么是不可能将所有数据载入内存的,必将导致内存泄漏,这时候我们可以用fit_generator函数来进行训练。

fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

以给定数量的轮次(数据集上的迭代)训练模型。

参数

x: 训练数据的 Numpy 数组(如果模型只有一个输入), 或者是 Numpy 数组的列表(如果模型有多个输入)。 如果模型中的输入层被命名,你也可以传递一个字典,将输入层名称映射到 Numpy 数组。 如果从本地框架张量馈送(例如 TensorFlow 数据张量)数据,x 可以是 None(默认)。

y: 目标(标签)数据的 Numpy 数组(如果模型只有一个输出), 或者是 Numpy 数组的列表(如果模型有多个输出)。 如果模型中的输出层被命名,你也可以传递一个字典,将输出层名称映射到 Numpy 数组。 如果从本地框架张量馈送(例如 TensorFlow 数据张量)数据,y 可以是 None(默认)。

batch_size: 整数或 None。每次梯度更新的样本数。如果未指定,默认为 32。

epochs: 整数。训练模型迭代轮次。一个轮次是在整个 x 和 y 上的一轮迭代。 请注意,与 initial_epoch 一起,epochs 被理解为 「最终轮次」。模型并不是训练了 epochs 轮,而是到第 epochs 轮停止训练。

verbose: 0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行。

callbacks: 一系列的 keras.callbacks.Callback 实例。一系列可以在训练时使用的回调函数。 详见 callbacks。

validation_split: 0 和 1 之间的浮点数。用作验证集的训练数据的比例。 模型将分出一部分不会被训练的验证数据,并将在每一轮结束时评估这些验证数据的误差和任何其他模型指标。 验证数据是混洗之前 x 和y 数据的最后一部分样本中。

validation_data: 元组 (x_val,y_val) 或元组 (x_val,y_val,val_sample_weights), 用来评估损失,以及在每轮结束时的任何模型度量指标。 模型将不会在这个数据上进行训练。这个参数会覆盖 validation_split。

shuffle: 布尔值(是否在每轮迭代之前混洗数据)或者 字符串 (batch)。 batch 是处理 HDF5 数据限制的特殊选项,它对一个 batch 内部的数据进行混洗。 当 steps_per_epoch 非 None 时,这个参数无效。

class_weight: 可选的字典,用来映射类索引(整数)到权重(浮点)值,用于加权损失函数(仅在训练期间)。 这可能有助于告诉模型 「更多关注」来自代表性不足的类的样本。

sample_weight: 训练样本的可选 Numpy 权重数组,用于对损失函数进行加权(仅在训练期间)。 您可以传递与输入样本长度相同的平坦(1D)Numpy 数组(权重和样本之间的 1:1 映射), 或者在时序数据的情况下,可以传递尺寸为 (samples, sequence_length) 的 2D 数组,以对每个样本的每个时间步施加不同的权重。 在这种情况下,你应该确保在 compile() 中指定 sample_weight_mode=“temporal”。

initial_epoch: 整数。开始训练的轮次(有助于恢复之前的训练)。

steps_per_epoch: 整数或 None。 在声明一个轮次完成并开始下一个轮次之前的总步数(样品批次)。 使用 TensorFlow 数据张量等输入张量进行训练时,默认值 None 等于数据集中样本的数量除以 batch 的大小,如果无法确定,则为 1。

validation_steps: 只有在指定了 steps_per_epoch 时才有用。停止前要验证的总步数(批次样本)。

返回

一个 History 对象。其 History.history 属性是连续 epoch 训练损失和评估值,以及验证集损失和评估值的记录(如果适用)。

异常

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

使用 Python 生成器(或 Sequence 实例)逐批生成的数据,按批次训练模型。

生成器与模型并行运行,以提高效率。 例如,这可以让你在 CPU 上对图像进行实时数据增强,以在 GPU 上训练模型。

keras.utils.Sequence 的使用可以保证数据的顺序, 以及当 use_multiprocessing=True 时 ,保证每个输入在每个 epoch 只使用一次。

参数

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

这个元组(生成器的单个输出)组成了单个的 batch。 因此,这个元组中的所有数组长度必须相同(与这一个 batch 的大小相等)。 不同的 batch 可能大小不同。 例如,一个 epoch 的最后一个 batch 往往比其他 batch 要小, 如果数据集的尺寸不能被 batch size 整除。 生成器将无限地在数据集上循环。当运行到第 steps_per_epoch 时,记一个 epoch 结束。

steps_per_epoch: 在声明一个 epoch 完成并开始下一个 epoch 之前从 generator 产生的总步数(批次样本)。 它通常应该等于你的数据集的样本数量除以批量大小。 对于 Sequence,它是可选的:如果未指定,将使用len(generator) 作为步数。

epochs: 整数。训练模型的迭代总轮数。一个 epoch 是对所提供的整个数据的一轮迭代,如 steps_per_epoch 所定义。注意,与 initial_epoch 一起使用,epoch 应被理解为「最后一轮」。模型没有经历由 epochs 给出的多次迭代的训练,而仅仅是直到达到索引 epoch 的轮次。

verbose: 0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行。

callbacks: keras.callbacks.Callback 实例的列表。在训练时调用的一系列回调函数。

validation_data: 它可以是以下之一:

验证数据的生成器或 Sequence 实例

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

在每个 epoch 结束时评估损失和任何模型指标。该模型不会对此数据进行训练。

validation_steps: 仅当 validation_data 是一个生成器时才可用。 在停止前 generator 生成的总步数(样本批数)。 对于 Sequence,它是可选的:如果未指定,将使用 len(generator) 作为步数。

class_weight: 可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。

max_queue_size: 整数。生成器队列的最大尺寸。 如未指定,max_queue_size 将默认为 10。

workers: 整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。

use_multiprocessing: 布尔值。如果 True,则使用基于进程的多线程。 如未指定, use_multiprocessing 将默认为 False。 请注意,由于此实现依赖于多进程,所以不应将不可传递的参数传递给生成器,因为它们不能被轻易地传递给子进程。

shuffle: 是否在每轮迭代之前打乱 batch 的顺序。 只能与 Sequence (keras.utils.Sequence) 实例同用。

initial_epoch: 开始训练的轮次(有助于恢复之前的训练)。

返回

一个 History 对象。其 History.history 属性是连续 epoch 训练损失和评估值,以及验证集损失和评估值的记录(如果适用)。

异常

ValueError: 如果生成器生成的数据格式不正确。

model.fit_generator(
  train_generator,
  steps_per_epoch=10, # 100
  validation_steps=1, # 50
  epochs=600, # 20个周期
  validation_data=validation_generator)

以上这篇Keras之fit_generator与train_on_batch用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中操作字符串之rstrip()方法的使用
May 19 Python
django中的HTML控件及参数传递方法
Mar 20 Python
详解Python的hasattr() getattr() setattr() 函数使用方法
Jul 09 Python
python生成九宫格图片
Nov 19 Python
PyQt5 实现给窗口设置背景图片的方法
Jun 13 Python
python 计算数据偏差和峰度的方法
Jun 29 Python
Python generator生成器和yield表达式详解
Aug 08 Python
python递归下载文件夹下所有文件
Aug 31 Python
python字符串下标与切片及使用方法
Feb 13 Python
Python调用OpenCV实现图像平滑代码实例
Jun 19 Python
使用py-spy解决scrapy卡死的问题方法
Sep 29 Python
使用Python制作一个数据预处理小工具(多种操作一键完成)
Feb 07 Python
基于Keras的格式化输出Loss实现方式
Jun 17 #Python
Tensorflow之MNIST CNN实现并保存、加载模型
Jun 17 #Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
You might like
php笔记之常用文件操作
2010/10/12 PHP
PHP基于SMTP协议实现邮件发送实例代码
2017/04/27 PHP
jquery 图片截取工具jquery.imagecropper.js
2010/04/09 Javascript
NodeJS url验证(url-valid)的使用方法
2013/11/18 NodeJs
js实现有时间限制消失的图片方法
2015/02/27 Javascript
js实现鼠标划过给div加透明度的方法
2015/05/25 Javascript
jQuery下拉美化搜索表单效果代码分享
2015/08/25 Javascript
javascript实现简单计算器效果【推荐】
2016/04/19 Javascript
Vue.JS入门教程之处理表单
2016/12/01 Javascript
JS实现数组去重方法总结(六种方法)
2017/07/14 Javascript
JS实现的倒计时恢复按钮点击功能【可用于协议阅读倒计时】
2018/04/19 Javascript
微信小程序实现自上而下字幕滚动
2018/07/14 Javascript
JS实现联想、自动补齐国家或地区名称的功能
2020/07/07 Javascript
解决vue-photo-preview 异步图片放大失效的问题
2020/07/29 Javascript
[03:56]还原FTP电影首映式 DOTA2群星拼出遗迹世界
2014/03/26 DOTA
[02:56]《DAC最前线》之国外战队抵达上海备战亚洲邀请赛
2015/01/28 DOTA
python 布尔操作实现代码
2013/03/23 Python
Python OpenCV获取视频的方法
2018/02/28 Python
Django中日期处理注意事项与自定义时间格式转换详解
2018/08/06 Python
对Python 3.2 迭代器的next函数实例讲解
2018/10/18 Python
Python数据类型之List列表实例详解
2019/05/08 Python
python3.7 的新特性详解
2019/07/25 Python
详解Python 字符串相似性的几种度量方法
2019/08/29 Python
tensorflow实现残差网络方式(mnist数据集)
2020/05/26 Python
美国婴童服装市场上的领先品牌:Carter’s
2018/02/08 全球购物
巴西儿童时尚购物网站:Dinda
2019/08/14 全球购物
德国W家官网,可直邮中国的母婴商城:Windeln.de
2021/03/03 全球购物
问卷调查计划书
2014/01/10 职场文书
高中同学聚会邀请函
2014/01/11 职场文书
保安岗位职责
2014/02/21 职场文书
家庭财产分割协议范文
2014/11/24 职场文书
2015年小学图书室工作总结
2015/05/18 职场文书
医疗纠纷调解协议书
2015/08/06 职场文书
2016年教师政治思想表现评语
2015/12/02 职场文书
《原神》新角色演示“神里绫人:林隐泓洄” 宠妹狂魔
2022/04/03 其他游戏
React四级菜单的实现
2022/04/08 Javascript