基于keras中的回调函数用法说明


Posted in Python onJune 17, 2020

keras训练

fit(
 self, 
 x, 
 y, 
 batch_size=32, 
 nb_epoch=10, 
 verbose=1, 
 callbacks=[], 
 validation_split=0.0, 
 validation_data=None, 
 shuffle=True, 
 class_weight=None, 
 sample_weight=None
)

1. x:输入数据。如果模型只有一个输入,那么x的类型是numpy array,如果模型有多个输入,那么x的类型应当为list,list的元素是对应于各个输入的numpy array。如果模型的每个输入都有名字,则可以传入一个字典,将输入名与其输入数据对应起来。

2. y:标签,numpy array。如果模型有多个输出,可以传入一个numpy array的list。如果模型的输出拥有名字,则可以传入一个字典,将输出名与其标签对应起来。

3. batch_size:整数,指定进行梯度下降时每个batch包含的样本数。训练时一个batch的样本会被计算一次梯度下降,使目标函数优化一步。

4. nb_epoch:整数,训练的轮数,训练数据将会被遍历nb_epoch次。Keras中nb开头的变量均为"number of"的意思

5. verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录

6. callbacks:list,其中的元素是keras.callbacks.Callback的对象。这个list中的回调函数将会在训练过程中的适当时机被调用,参考回调函数

7. validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。

8. validation_data:形式为(X,y)或(X,y,sample_weights)的tuple,是指定的验证集。此参数将覆盖validation_spilt。

9. shuffle:布尔值,表示是否在训练过程中每个epoch前随机打乱输入样本的顺序。

10. class_weight:字典,将不同的类别映射为不同的权值,该参数用来在训练过程中调整损失函数(只能用于训练)。该参数在处理非平衡的训练数据(某些类的训练样本数很少)时,可以使得损失函数对样本数不足的数据更加关注。

11. sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode='temporal'。

fit函数返回一个History的对象,其History.history属性记录了损失函数和其他指标的数值随epoch变化的情况,如果有验证集的话,也包含了验证集的这些指标变化情况。

保存模型结构、训练出来的权重、及优化器状态

keras 的 callback参数可以帮助我们实现在训练过程中的适当时机被调用。实现实时保存训练模型以及训练参数。

keras.callbacks.ModelCheckpoint(
 filepath, 
 monitor='val_loss', 
 verbose=0, 
 save_best_only=False, 
 save_weights_only=False, 
 mode='auto', 
 period=1
)

1. filename:字符串,保存模型的路径

2. monitor:需要监视的值

3. verbose:信息展示模式,0或1

4. save_best_only:当设置为True时,将只保存在验证集上性能最好的模型

5. mode:‘auto',‘min',‘max'之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。

6. save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)

7. period:CheckPoint之间的间隔的epoch数

当验证损失不再继续降低时,如何中断训练?当监测值不再改善时中止训练

用EarlyStopping回调函数

from keras.callbacksimport EarlyStopping 

keras.callbacks.EarlyStopping(
 monitor='val_loss', 
 patience=0, 
 verbose=0, 
 mode='auto'
)

model.fit(X, y, validation_split=0.2, callbacks=[early_stopping])

1. monitor:需要监视的量

2. patience:当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。

3. verbose:信息展示模式

4. mode:‘auto',‘min',‘max'之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。

学习率动态调整1

keras.callbacks.LearningRateScheduler(schedule)

schedule:函数,该函数以epoch号为参数(从0算起的整数),返回一个新学习率(浮点数)

也可以让keras自动调整学习率

keras.callbacks.ReduceLROnPlateau(
 monitor='val_loss', 
 factor=0.1, 
 patience=10, 
 verbose=0, 
 mode='auto', 
 epsilon=0.0001, 
 cooldown=0, 
 min_lr=0
)

1. monitor:被监测的量

2. factor:每次减少学习率的因子,学习率将以lr = lr*factor的形式被减少

3. patience:当patience个epoch过去而模型性能不提升时,学习率减少的动作会被触发

4. mode:‘auto',‘min',‘max'之一,在min模式下,如果检测值触发学习率减少。在max模式下,当检测值不再上升则触发学习率减少。

5. epsilon:阈值,用来确定是否进入检测值的“平原区”

6. cooldown:学习率减少后,会经过cooldown个epoch才重新进行正常操作

7. min_lr:学习率的下限

当学习停滞时,减少2倍或10倍的学习率常常能获得较好的效果

学习率动态2

def step_decay(epoch):
 initial_lrate = 0.01
 drop = 0.5
 epochs_drop = 10.0
 lrate = initial_lrate * math.pow(drop,math.floor((1+epoch)/epochs_drop))
 return lrate
lrate = LearningRateScheduler(step_decay)
sgd = SGD(lr=0.0, momentum=0.9, decay=0.0, nesterov=False)
model.fit(train_set_x, train_set_y, validation_split=0.1, nb_epoch=200, batch_size=256, callbacks=[lrate])

如何记录每一次epoch的训练/验证损失/准确度?

Model.fit函数会返回一个 History 回调,该回调有一个属性history包含一个封装有连续损失/准确的lists。代码如下:

hist = model.fit(X, y,validation_split=0.2)
print(hist.history)

Keras输出的loss,val这些值如何保存到文本中去

Keras中的fit函数会返回一个History对象,它的History.history属性会把之前的那些值全保存在里面,如果有验证集的话,也包含了验证集的这些指标变化情况,具体写法

hist=model.fit(train_set_x,train_set_y,batch_size=256,shuffle=True,nb_epoch=nb_epoch,validation_split=0.1)
with open('log_sgd_big_32.txt','w') as f:
 f.write(str(hist.history))

示例,多个回调函数用逗号隔开

# checkpoint
checkpointer = ModelCheckpoint(filepath="./checkpoint.hdf5", verbose=1)
# learning rate adjust dynamic
lrate = ReduceLROnPlateau(min_lr=0.00001)

answer.compile(optimizer='rmsprop', loss='categorical_crossentropy',
    metrics=['accuracy'])
# Note: you could use a Graph model to avoid repeat the input twice
answer.fit(
 [inputs_train, queries_train, inputs_train], answers_train,
 batch_size=32,
 nb_epoch=5000,
 validation_data=([inputs_test, queries_test, inputs_test], answers_test),
 callbacks=[checkpointer, lrate]
)

keras回调函数中的Tensorboard

keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0, 
   write_graph=True, write_images=True)

tbCallBack = keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True)
...
model.fit(...inputs and parameters..., callbacks=[tbCallBack])
tensorboard --logdir path_to_current_dir/Graph

或者

from keras.callbacks import TensorBoard

tensorboard = TensorBoard(log_dir='./logs', histogram_freq=0,
       write_graph=True, write_images=False)
# define model
model.fit(X_train, Y_train,
   batch_size=batch_size,
   epochs=nb_epoch,
   validation_data=(X_test, Y_test),
   shuffle=True,
   callbacks=[tensorboard])

补充知识:Keras中的回调函数(callback)的使用与介绍

以前我在训练的时候,都是直接设定一个比较大的epoch,跑完所有的epoch之后再根据数据去调整模型与参数。这样做会比较耗时,例如说训练在某一个epoch开始已经过拟合了,后面继续训练意义就不大了。

在书上看到的callback函数很好的解决了这个问题,它能够监测训练过程中的loss或者acc这些指标,一旦观察到损失不再改善之后,就可以中止训练,节省时间。下面记录一下

介绍:

(选自《python深度学习》)

回调函数(callback)是在调用fit时传入模型的一个对象,它在训练过程中的不同时间点都会被模型调用。它可以访问关于模型状态与性能的所有可用数据,还可以采取行动:中断训练、保存模型、加载一组不同的权重或改变模型的状态。

部分回调函数:

1.ModelCheckpoint与EarlyStopping

监控目标若在指定轮数内不再改善,可利用EarlyStopping来中断训练。

可配合ModelCheckpoint使用,该回调函数可不断地保存模型,亦可以只保存某一epoch最佳性能模型

import keras
callbacks_list=[
 keras.callbacks.EarlyStopping(
  monitor='acc',#监控精度
  patience=5,#5轮内不改善就中止
),
 keras.callbacks.ModelCheckpoint(
  filepath='C:/apple/my_model.h5',#模型保存路径
  monitor='val_loss',#检测验证集损失值
  save_best_only=True#是否只保存最佳模型
 )
]
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['acc'])
model.fit(x,y,
   epochs=10,
   batch_size=32,
   callbacks=callbacks_list,#在这里放入callback函数
   validation_data=(x_val,y_val)
 )

2.ReduceLROnPlateau回调函数

如果验证损失不再改善,可以使用该回调函数来降低学习率。

import keras
 
callbacks_list=[
 keras.callbacks.ReduceLROnPlateau(
  monitor='val_loss',#监控精度
  patienece=5, # 5轮内不改善就改变
  factor=0.1#学习率变为原来的0.1
)
]
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['acc'])
model.fit(x,y,
   epochs=10,
   batch_size=32,
   callbacks=callbacks_list,#在这里放入callback函数
   validation_data=(x_val,y_val)
 )

以上这篇基于keras中的回调函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python在windows和linux下获得本机本地ip地址方法小结
Mar 20 Python
Python3通过Luhn算法快速验证信用卡卡号的方法
May 14 Python
python实现将英文单词表示的数字转换成阿拉伯数字的方法
Jul 02 Python
Python自动化部署工具Fabric的简单上手指南
Apr 19 Python
Python利用递归和walk()遍历目录文件的方法示例
Jul 14 Python
python基础梳理(一)(推荐)
Apr 06 Python
python爬虫之快速对js内容进行破解
Jul 09 Python
Python Flask上下文管理机制实例解析
Mar 16 Python
Django使用rest_framework写出API
May 21 Python
解决pycharm中的run和debug失效无法点击运行
Jun 09 Python
python与c语言的语法有哪些不一样的
Sep 13 Python
pycharm 2020.2.4 pip install Flask 报错 Error:Non-zero exit code的问题
Dec 04 Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
浅谈Python协程
Jun 17 #Python
使用K.function()调试keras操作
Jun 17 #Python
哪些是python中web开发框架
Jun 17 #Python
python如何处理程序无法打开
Jun 16 #Python
python模块如何查看
Jun 16 #Python
You might like
用PHP发电子邮件
2006/10/09 PHP
php中debug_backtrace、debug_print_backtrace和匿名函数用法实例
2014/12/01 PHP
PHP的命令行命令使用指南
2015/08/18 PHP
PHP实现的限制IP投票程序IP来源分析
2016/05/04 PHP
PHP如何读取由JavaScript设置的Cookie
2017/03/22 PHP
PHP面向对象程序设计__tostring()和__invoke()用法分析
2019/06/12 PHP
一个无限级XML绑定跨框架菜单(For IE)
2007/01/27 Javascript
jQuery打印图片pdf、txt示例代码
2014/07/22 Javascript
Nodejs关于gzip/deflate压缩详解
2015/03/04 NodeJs
充分发挥Node.js程序性能的一些方法介绍
2015/06/23 Javascript
jQuery的框架介绍
2016/05/11 Javascript
JS实现经典的中国地区三级联动下拉菜单功能实例【测试可用】
2017/06/06 Javascript
ES6中Array.includes()函数的用法
2017/09/20 Javascript
vue单个组件实现无限层级多选菜单功能
2018/04/10 Javascript
webpack配置proxyTable时pathRewrite无效的解决方法
2018/12/13 Javascript
vue+php实现的微博留言功能示例
2019/03/16 Javascript
vue+Element中table表格实现可编辑(select下拉框)
2020/05/21 Javascript
Vue.js使用axios动态获取response里的data数据操作
2020/09/08 Javascript
[55:35]VGJ.S vs Mski Supermajor小组赛C组 BO3 第二场 6.3
2018/06/04 DOTA
python实现socket端口重定向示例
2014/02/10 Python
python求列表交集的方法汇总
2014/11/10 Python
Python实现两个list求交集,并集,差集的方法示例
2018/08/02 Python
如何查看Django ORM执行的SQL语句的实现
2020/04/20 Python
python中sympy库求常微分方程的用法
2020/04/28 Python
Python计算信息熵实例
2020/06/18 Python
MoviePy简介及Python视频剪辑自动化
2020/12/18 Python
canvas简易绘图的实现(海绵宝宝篇)
2018/07/04 HTML / CSS
秘鲁购物网站:Linio秘鲁
2017/04/07 全球购物
澳大利亚最大的在线美发和美容零售商之一:My Hair Care & Beauty
2019/08/24 全球购物
科研课题实施方案
2014/03/18 职场文书
完美的中文自荐信
2014/05/24 职场文书
文秘专业应届生求职信
2014/05/26 职场文书
企业精神口号
2014/06/11 职场文书
跟班学习心得体会(共6篇)
2016/01/23 职场文书
经典励志格言:每日一句,让你每天充满能量
2019/08/16 职场文书
使用vue判断当前环境是安卓还是IOS
2022/04/12 Vue.js