基于keras中的回调函数用法说明


Posted in Python onJune 17, 2020

keras训练

fit(
 self, 
 x, 
 y, 
 batch_size=32, 
 nb_epoch=10, 
 verbose=1, 
 callbacks=[], 
 validation_split=0.0, 
 validation_data=None, 
 shuffle=True, 
 class_weight=None, 
 sample_weight=None
)

1. x:输入数据。如果模型只有一个输入,那么x的类型是numpy array,如果模型有多个输入,那么x的类型应当为list,list的元素是对应于各个输入的numpy array。如果模型的每个输入都有名字,则可以传入一个字典,将输入名与其输入数据对应起来。

2. y:标签,numpy array。如果模型有多个输出,可以传入一个numpy array的list。如果模型的输出拥有名字,则可以传入一个字典,将输出名与其标签对应起来。

3. batch_size:整数,指定进行梯度下降时每个batch包含的样本数。训练时一个batch的样本会被计算一次梯度下降,使目标函数优化一步。

4. nb_epoch:整数,训练的轮数,训练数据将会被遍历nb_epoch次。Keras中nb开头的变量均为"number of"的意思

5. verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录

6. callbacks:list,其中的元素是keras.callbacks.Callback的对象。这个list中的回调函数将会在训练过程中的适当时机被调用,参考回调函数

7. validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。

8. validation_data:形式为(X,y)或(X,y,sample_weights)的tuple,是指定的验证集。此参数将覆盖validation_spilt。

9. shuffle:布尔值,表示是否在训练过程中每个epoch前随机打乱输入样本的顺序。

10. class_weight:字典,将不同的类别映射为不同的权值,该参数用来在训练过程中调整损失函数(只能用于训练)。该参数在处理非平衡的训练数据(某些类的训练样本数很少)时,可以使得损失函数对样本数不足的数据更加关注。

11. sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode='temporal'。

fit函数返回一个History的对象,其History.history属性记录了损失函数和其他指标的数值随epoch变化的情况,如果有验证集的话,也包含了验证集的这些指标变化情况。

保存模型结构、训练出来的权重、及优化器状态

keras 的 callback参数可以帮助我们实现在训练过程中的适当时机被调用。实现实时保存训练模型以及训练参数。

keras.callbacks.ModelCheckpoint(
 filepath, 
 monitor='val_loss', 
 verbose=0, 
 save_best_only=False, 
 save_weights_only=False, 
 mode='auto', 
 period=1
)

1. filename:字符串,保存模型的路径

2. monitor:需要监视的值

3. verbose:信息展示模式,0或1

4. save_best_only:当设置为True时,将只保存在验证集上性能最好的模型

5. mode:‘auto',‘min',‘max'之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。

6. save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)

7. period:CheckPoint之间的间隔的epoch数

当验证损失不再继续降低时,如何中断训练?当监测值不再改善时中止训练

用EarlyStopping回调函数

from keras.callbacksimport EarlyStopping 

keras.callbacks.EarlyStopping(
 monitor='val_loss', 
 patience=0, 
 verbose=0, 
 mode='auto'
)

model.fit(X, y, validation_split=0.2, callbacks=[early_stopping])

1. monitor:需要监视的量

2. patience:当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。

3. verbose:信息展示模式

4. mode:‘auto',‘min',‘max'之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。

学习率动态调整1

keras.callbacks.LearningRateScheduler(schedule)

schedule:函数,该函数以epoch号为参数(从0算起的整数),返回一个新学习率(浮点数)

也可以让keras自动调整学习率

keras.callbacks.ReduceLROnPlateau(
 monitor='val_loss', 
 factor=0.1, 
 patience=10, 
 verbose=0, 
 mode='auto', 
 epsilon=0.0001, 
 cooldown=0, 
 min_lr=0
)

1. monitor:被监测的量

2. factor:每次减少学习率的因子,学习率将以lr = lr*factor的形式被减少

3. patience:当patience个epoch过去而模型性能不提升时,学习率减少的动作会被触发

4. mode:‘auto',‘min',‘max'之一,在min模式下,如果检测值触发学习率减少。在max模式下,当检测值不再上升则触发学习率减少。

5. epsilon:阈值,用来确定是否进入检测值的“平原区”

6. cooldown:学习率减少后,会经过cooldown个epoch才重新进行正常操作

7. min_lr:学习率的下限

当学习停滞时,减少2倍或10倍的学习率常常能获得较好的效果

学习率动态2

def step_decay(epoch):
 initial_lrate = 0.01
 drop = 0.5
 epochs_drop = 10.0
 lrate = initial_lrate * math.pow(drop,math.floor((1+epoch)/epochs_drop))
 return lrate
lrate = LearningRateScheduler(step_decay)
sgd = SGD(lr=0.0, momentum=0.9, decay=0.0, nesterov=False)
model.fit(train_set_x, train_set_y, validation_split=0.1, nb_epoch=200, batch_size=256, callbacks=[lrate])

如何记录每一次epoch的训练/验证损失/准确度?

Model.fit函数会返回一个 History 回调,该回调有一个属性history包含一个封装有连续损失/准确的lists。代码如下:

hist = model.fit(X, y,validation_split=0.2)
print(hist.history)

Keras输出的loss,val这些值如何保存到文本中去

Keras中的fit函数会返回一个History对象,它的History.history属性会把之前的那些值全保存在里面,如果有验证集的话,也包含了验证集的这些指标变化情况,具体写法

hist=model.fit(train_set_x,train_set_y,batch_size=256,shuffle=True,nb_epoch=nb_epoch,validation_split=0.1)
with open('log_sgd_big_32.txt','w') as f:
 f.write(str(hist.history))

示例,多个回调函数用逗号隔开

# checkpoint
checkpointer = ModelCheckpoint(filepath="./checkpoint.hdf5", verbose=1)
# learning rate adjust dynamic
lrate = ReduceLROnPlateau(min_lr=0.00001)

answer.compile(optimizer='rmsprop', loss='categorical_crossentropy',
    metrics=['accuracy'])
# Note: you could use a Graph model to avoid repeat the input twice
answer.fit(
 [inputs_train, queries_train, inputs_train], answers_train,
 batch_size=32,
 nb_epoch=5000,
 validation_data=([inputs_test, queries_test, inputs_test], answers_test),
 callbacks=[checkpointer, lrate]
)

keras回调函数中的Tensorboard

keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0, 
   write_graph=True, write_images=True)

tbCallBack = keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True)
...
model.fit(...inputs and parameters..., callbacks=[tbCallBack])
tensorboard --logdir path_to_current_dir/Graph

或者

from keras.callbacks import TensorBoard

tensorboard = TensorBoard(log_dir='./logs', histogram_freq=0,
       write_graph=True, write_images=False)
# define model
model.fit(X_train, Y_train,
   batch_size=batch_size,
   epochs=nb_epoch,
   validation_data=(X_test, Y_test),
   shuffle=True,
   callbacks=[tensorboard])

补充知识:Keras中的回调函数(callback)的使用与介绍

以前我在训练的时候,都是直接设定一个比较大的epoch,跑完所有的epoch之后再根据数据去调整模型与参数。这样做会比较耗时,例如说训练在某一个epoch开始已经过拟合了,后面继续训练意义就不大了。

在书上看到的callback函数很好的解决了这个问题,它能够监测训练过程中的loss或者acc这些指标,一旦观察到损失不再改善之后,就可以中止训练,节省时间。下面记录一下

介绍:

(选自《python深度学习》)

回调函数(callback)是在调用fit时传入模型的一个对象,它在训练过程中的不同时间点都会被模型调用。它可以访问关于模型状态与性能的所有可用数据,还可以采取行动:中断训练、保存模型、加载一组不同的权重或改变模型的状态。

部分回调函数:

1.ModelCheckpoint与EarlyStopping

监控目标若在指定轮数内不再改善,可利用EarlyStopping来中断训练。

可配合ModelCheckpoint使用,该回调函数可不断地保存模型,亦可以只保存某一epoch最佳性能模型

import keras
callbacks_list=[
 keras.callbacks.EarlyStopping(
  monitor='acc',#监控精度
  patience=5,#5轮内不改善就中止
),
 keras.callbacks.ModelCheckpoint(
  filepath='C:/apple/my_model.h5',#模型保存路径
  monitor='val_loss',#检测验证集损失值
  save_best_only=True#是否只保存最佳模型
 )
]
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['acc'])
model.fit(x,y,
   epochs=10,
   batch_size=32,
   callbacks=callbacks_list,#在这里放入callback函数
   validation_data=(x_val,y_val)
 )

2.ReduceLROnPlateau回调函数

如果验证损失不再改善,可以使用该回调函数来降低学习率。

import keras
 
callbacks_list=[
 keras.callbacks.ReduceLROnPlateau(
  monitor='val_loss',#监控精度
  patienece=5, # 5轮内不改善就改变
  factor=0.1#学习率变为原来的0.1
)
]
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['acc'])
model.fit(x,y,
   epochs=10,
   batch_size=32,
   callbacks=callbacks_list,#在这里放入callback函数
   validation_data=(x_val,y_val)
 )

以上这篇基于keras中的回调函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
理解python多线程(python多线程简明教程)
Jun 09 Python
Python构造函数及解构函数介绍
Feb 26 Python
python中os和sys模块的区别与常用方法总结
Nov 14 Python
python+pillow绘制矩阵盖尔圆简单实例
Jan 16 Python
python 爬虫一键爬取 淘宝天猫宝贝页面主图颜色图和详情图的教程
May 22 Python
用python打印1~20的整数实例讲解
Jul 01 Python
docker-py 用Python调用Docker接口的方法
Aug 30 Python
Python实现串口通信(pyserial)过程解析
Sep 25 Python
pyqt5中动画的使用详解
Apr 01 Python
Python unittest单元测试框架及断言方法
Apr 15 Python
python自动统计zabbix系统监控覆盖率的示例代码
Apr 03 Python
python如何读取.mtx文件
Apr 22 Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
浅谈Python协程
Jun 17 #Python
使用K.function()调试keras操作
Jun 17 #Python
哪些是python中web开发框架
Jun 17 #Python
python如何处理程序无法打开
Jun 16 #Python
python模块如何查看
Jun 16 #Python
You might like
深入php 正则表达式的学习探讨
2013/06/06 PHP
基于laravel制作APP接口(API)
2016/03/15 PHP
js的event详解。
2006/09/06 Javascript
基于Jquery的简单&简陋Tabs插件代码
2010/02/09 Javascript
JS将光标聚焦在文本最后的实现代码
2014/03/28 Javascript
JavaScript继承基础讲解(原型链、借用构造函数、混合模式、原型式继承、寄生式继承、寄生组合式继承)
2014/08/16 Javascript
JavaScript跨域方法汇总
2014/10/16 Javascript
JavaScript学习心得之概述
2015/01/20 Javascript
js window对象属性和方法相关资料整理
2015/11/11 Javascript
浅析JavaScript 箭头函数 generator Date JSON
2016/05/23 Javascript
js选择器全面解析
2016/06/27 Javascript
两种JavaScript的AES加密方式(可与Java相互加解密)
2016/08/02 Javascript
Bootstrap基本插件学习笔记之折叠(22)
2016/12/08 Javascript
JS回调函数简单用法示例
2017/02/09 Javascript
JavaScript函数中的this四种绑定形式
2017/08/15 Javascript
微信小程序 MinUI组件库系列之badge徽章组件示例
2018/08/20 Javascript
js实现点赞按钮功能的实例代码
2020/03/06 Javascript
vue过滤器实现日期格式化的案例分析
2020/07/02 Javascript
vue treeselect获取当前选中项的label实例
2020/08/31 Javascript
jQuery实现动态操作table行
2020/11/23 jQuery
[02:19]DOTA选手解说齐贺岁
2018/02/11 DOTA
python3新特性函数注释Function Annotations用法分析
2016/07/28 Python
python自动识别文本编码格式代码
2019/12/26 Python
印度排名第一的蛋糕、鲜花和礼品送货:Winni
2019/08/02 全球购物
党校自我鉴定范文
2013/10/02 职场文书
三好学生个人先进事迹材料
2014/05/17 职场文书
村干部群众路线教育活动对照检查材料
2014/10/01 职场文书
村党的群众路线教育实践活动工作总结
2014/10/25 职场文书
劳动纠纷调解协议书格式
2014/11/30 职场文书
银行客户经理岗位职责
2015/04/09 职场文书
保护校园环境倡议书
2015/04/28 职场文书
2016天猫双十一广告语
2016/01/28 职场文书
2016年优秀班主任先进事迹材料
2016/02/26 职场文书
节约用水广告语60条
2019/11/14 职场文书
八年级作文之感恩
2019/11/22 职场文书
2019年12月24日平安夜祝福语集锦
2019/12/24 职场文书