基于keras中的回调函数用法说明


Posted in Python onJune 17, 2020

keras训练

fit(
 self, 
 x, 
 y, 
 batch_size=32, 
 nb_epoch=10, 
 verbose=1, 
 callbacks=[], 
 validation_split=0.0, 
 validation_data=None, 
 shuffle=True, 
 class_weight=None, 
 sample_weight=None
)

1. x:输入数据。如果模型只有一个输入,那么x的类型是numpy array,如果模型有多个输入,那么x的类型应当为list,list的元素是对应于各个输入的numpy array。如果模型的每个输入都有名字,则可以传入一个字典,将输入名与其输入数据对应起来。

2. y:标签,numpy array。如果模型有多个输出,可以传入一个numpy array的list。如果模型的输出拥有名字,则可以传入一个字典,将输出名与其标签对应起来。

3. batch_size:整数,指定进行梯度下降时每个batch包含的样本数。训练时一个batch的样本会被计算一次梯度下降,使目标函数优化一步。

4. nb_epoch:整数,训练的轮数,训练数据将会被遍历nb_epoch次。Keras中nb开头的变量均为"number of"的意思

5. verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录

6. callbacks:list,其中的元素是keras.callbacks.Callback的对象。这个list中的回调函数将会在训练过程中的适当时机被调用,参考回调函数

7. validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。

8. validation_data:形式为(X,y)或(X,y,sample_weights)的tuple,是指定的验证集。此参数将覆盖validation_spilt。

9. shuffle:布尔值,表示是否在训练过程中每个epoch前随机打乱输入样本的顺序。

10. class_weight:字典,将不同的类别映射为不同的权值,该参数用来在训练过程中调整损失函数(只能用于训练)。该参数在处理非平衡的训练数据(某些类的训练样本数很少)时,可以使得损失函数对样本数不足的数据更加关注。

11. sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode='temporal'。

fit函数返回一个History的对象,其History.history属性记录了损失函数和其他指标的数值随epoch变化的情况,如果有验证集的话,也包含了验证集的这些指标变化情况。

保存模型结构、训练出来的权重、及优化器状态

keras 的 callback参数可以帮助我们实现在训练过程中的适当时机被调用。实现实时保存训练模型以及训练参数。

keras.callbacks.ModelCheckpoint(
 filepath, 
 monitor='val_loss', 
 verbose=0, 
 save_best_only=False, 
 save_weights_only=False, 
 mode='auto', 
 period=1
)

1. filename:字符串,保存模型的路径

2. monitor:需要监视的值

3. verbose:信息展示模式,0或1

4. save_best_only:当设置为True时,将只保存在验证集上性能最好的模型

5. mode:‘auto',‘min',‘max'之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。

6. save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)

7. period:CheckPoint之间的间隔的epoch数

当验证损失不再继续降低时,如何中断训练?当监测值不再改善时中止训练

用EarlyStopping回调函数

from keras.callbacksimport EarlyStopping 

keras.callbacks.EarlyStopping(
 monitor='val_loss', 
 patience=0, 
 verbose=0, 
 mode='auto'
)

model.fit(X, y, validation_split=0.2, callbacks=[early_stopping])

1. monitor:需要监视的量

2. patience:当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。

3. verbose:信息展示模式

4. mode:‘auto',‘min',‘max'之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。

学习率动态调整1

keras.callbacks.LearningRateScheduler(schedule)

schedule:函数,该函数以epoch号为参数(从0算起的整数),返回一个新学习率(浮点数)

也可以让keras自动调整学习率

keras.callbacks.ReduceLROnPlateau(
 monitor='val_loss', 
 factor=0.1, 
 patience=10, 
 verbose=0, 
 mode='auto', 
 epsilon=0.0001, 
 cooldown=0, 
 min_lr=0
)

1. monitor:被监测的量

2. factor:每次减少学习率的因子,学习率将以lr = lr*factor的形式被减少

3. patience:当patience个epoch过去而模型性能不提升时,学习率减少的动作会被触发

4. mode:‘auto',‘min',‘max'之一,在min模式下,如果检测值触发学习率减少。在max模式下,当检测值不再上升则触发学习率减少。

5. epsilon:阈值,用来确定是否进入检测值的“平原区”

6. cooldown:学习率减少后,会经过cooldown个epoch才重新进行正常操作

7. min_lr:学习率的下限

当学习停滞时,减少2倍或10倍的学习率常常能获得较好的效果

学习率动态2

def step_decay(epoch):
 initial_lrate = 0.01
 drop = 0.5
 epochs_drop = 10.0
 lrate = initial_lrate * math.pow(drop,math.floor((1+epoch)/epochs_drop))
 return lrate
lrate = LearningRateScheduler(step_decay)
sgd = SGD(lr=0.0, momentum=0.9, decay=0.0, nesterov=False)
model.fit(train_set_x, train_set_y, validation_split=0.1, nb_epoch=200, batch_size=256, callbacks=[lrate])

如何记录每一次epoch的训练/验证损失/准确度?

Model.fit函数会返回一个 History 回调,该回调有一个属性history包含一个封装有连续损失/准确的lists。代码如下:

hist = model.fit(X, y,validation_split=0.2)
print(hist.history)

Keras输出的loss,val这些值如何保存到文本中去

Keras中的fit函数会返回一个History对象,它的History.history属性会把之前的那些值全保存在里面,如果有验证集的话,也包含了验证集的这些指标变化情况,具体写法

hist=model.fit(train_set_x,train_set_y,batch_size=256,shuffle=True,nb_epoch=nb_epoch,validation_split=0.1)
with open('log_sgd_big_32.txt','w') as f:
 f.write(str(hist.history))

示例,多个回调函数用逗号隔开

# checkpoint
checkpointer = ModelCheckpoint(filepath="./checkpoint.hdf5", verbose=1)
# learning rate adjust dynamic
lrate = ReduceLROnPlateau(min_lr=0.00001)

answer.compile(optimizer='rmsprop', loss='categorical_crossentropy',
    metrics=['accuracy'])
# Note: you could use a Graph model to avoid repeat the input twice
answer.fit(
 [inputs_train, queries_train, inputs_train], answers_train,
 batch_size=32,
 nb_epoch=5000,
 validation_data=([inputs_test, queries_test, inputs_test], answers_test),
 callbacks=[checkpointer, lrate]
)

keras回调函数中的Tensorboard

keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0, 
   write_graph=True, write_images=True)

tbCallBack = keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True)
...
model.fit(...inputs and parameters..., callbacks=[tbCallBack])
tensorboard --logdir path_to_current_dir/Graph

或者

from keras.callbacks import TensorBoard

tensorboard = TensorBoard(log_dir='./logs', histogram_freq=0,
       write_graph=True, write_images=False)
# define model
model.fit(X_train, Y_train,
   batch_size=batch_size,
   epochs=nb_epoch,
   validation_data=(X_test, Y_test),
   shuffle=True,
   callbacks=[tensorboard])

补充知识:Keras中的回调函数(callback)的使用与介绍

以前我在训练的时候,都是直接设定一个比较大的epoch,跑完所有的epoch之后再根据数据去调整模型与参数。这样做会比较耗时,例如说训练在某一个epoch开始已经过拟合了,后面继续训练意义就不大了。

在书上看到的callback函数很好的解决了这个问题,它能够监测训练过程中的loss或者acc这些指标,一旦观察到损失不再改善之后,就可以中止训练,节省时间。下面记录一下

介绍:

(选自《python深度学习》)

回调函数(callback)是在调用fit时传入模型的一个对象,它在训练过程中的不同时间点都会被模型调用。它可以访问关于模型状态与性能的所有可用数据,还可以采取行动:中断训练、保存模型、加载一组不同的权重或改变模型的状态。

部分回调函数:

1.ModelCheckpoint与EarlyStopping

监控目标若在指定轮数内不再改善,可利用EarlyStopping来中断训练。

可配合ModelCheckpoint使用,该回调函数可不断地保存模型,亦可以只保存某一epoch最佳性能模型

import keras
callbacks_list=[
 keras.callbacks.EarlyStopping(
  monitor='acc',#监控精度
  patience=5,#5轮内不改善就中止
),
 keras.callbacks.ModelCheckpoint(
  filepath='C:/apple/my_model.h5',#模型保存路径
  monitor='val_loss',#检测验证集损失值
  save_best_only=True#是否只保存最佳模型
 )
]
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['acc'])
model.fit(x,y,
   epochs=10,
   batch_size=32,
   callbacks=callbacks_list,#在这里放入callback函数
   validation_data=(x_val,y_val)
 )

2.ReduceLROnPlateau回调函数

如果验证损失不再改善,可以使用该回调函数来降低学习率。

import keras
 
callbacks_list=[
 keras.callbacks.ReduceLROnPlateau(
  monitor='val_loss',#监控精度
  patienece=5, # 5轮内不改善就改变
  factor=0.1#学习率变为原来的0.1
)
]
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['acc'])
model.fit(x,y,
   epochs=10,
   batch_size=32,
   callbacks=callbacks_list,#在这里放入callback函数
   validation_data=(x_val,y_val)
 )

以上这篇基于keras中的回调函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python datetime时间格式化去掉前导0
Jul 31 Python
整理Python 常用string函数(收藏)
May 30 Python
python基于pyDes库实现des加密的方法
Apr 29 Python
Python实现的概率分布运算操作示例
Aug 14 Python
JSONLINT:python的json数据验证库实例解析
Nov 28 Python
python的socket编程入门
Jan 29 Python
Python实现的读写json文件功能示例
Jun 05 Python
Python3.6简单的操作Mysql数据库的三个实例
Oct 17 Python
解决python有时候import不了当前的包问题
Aug 28 Python
PyTorch中permute的用法详解
Dec 30 Python
Python对Tornado请求与响应的数据处理
Feb 12 Python
Python中关于logging模块的学习笔记
Jun 03 Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
浅谈Python协程
Jun 17 #Python
使用K.function()调试keras操作
Jun 17 #Python
哪些是python中web开发框架
Jun 17 #Python
python如何处理程序无法打开
Jun 16 #Python
python模块如何查看
Jun 16 #Python
You might like
PHP高级对象构建 工厂模式的使用
2012/02/05 PHP
PHP获取MSN好友列表类的实现代码
2013/06/23 PHP
PHP字符串比较函数strcmp()和strcasecmp()使用总结
2014/11/19 PHP
PHP实现向关联数组指定的Key之前插入元素的方法
2017/06/06 PHP
JavaScript 事件对象的实现
2009/07/13 Javascript
客户端限制只能上传jpg格式图片的js代码
2010/12/09 Javascript
jQuery写fadeTo示例代码
2014/02/21 Javascript
jQuery使用after()方法在元素后面添加多项内容的方法
2015/03/26 Javascript
javascript中setAttribute()函数使用方法及兼容性
2015/07/19 Javascript
jquery实现定时自动轮播特效
2015/12/10 Javascript
js数组常用操作方法小结(增加,删除,合并,分割等)
2016/08/02 Javascript
vue.js初学入门教程(1)
2016/11/03 Javascript
package.json文件配置详解
2017/06/15 Javascript
详解webpack自定义loader初探
2018/08/29 Javascript
vue interceptor 使用教程实例详解
2018/09/13 Javascript
Next.js实现react服务器端渲染的方法示例
2019/01/06 Javascript
jquery操作checkbox的常用方法总结【附测试源码下载】
2019/06/10 jQuery
JavaScript提升机制Hoisting详解
2019/10/23 Javascript
vue.js自定义组件实现v-model双向数据绑定的示例代码
2020/01/08 Javascript
使用SAE部署Python运行环境的教程
2015/05/05 Python
python决策树之C4.5算法详解
2017/12/20 Python
Python之两种模式的生产者消费者模型详解
2018/10/26 Python
Python队列RabbitMQ 使用方法实例记录
2019/08/05 Python
numpy.transpose()实现数组的转置例子
2019/12/02 Python
Canvas波浪花环的示例代码
2020/08/21 HTML / CSS
UGG雪地靴荷兰官网:UGG荷兰
2016/09/09 全球购物
商务英语专业自荐信
2013/10/14 职场文书
董事长职责范文
2013/11/08 职场文书
酒店副总岗位职责
2013/12/24 职场文书
应聘医药销售自荐书范文
2014/02/08 职场文书
护士个人年终总结
2015/02/13 职场文书
排球赛新闻稿
2015/07/17 职场文书
中学音乐课教学反思
2016/02/18 职场文书
数据库的高级查询六:表连接查询:外连接(左外连接,右外连接,UNION关键字,连接中ON与WHERE的不同)
2021/04/05 MySQL
MySQL InnoDB ReplicaSet(副本集)简单介绍
2021/04/24 MySQL
Redis如何实现验证码发送 以及限制每日发送次数
2022/04/18 Redis