keras 回调函数Callbacks 断点ModelCheckpoint教程


Posted in Python onJune 18, 2020

整理自keras:https://keras-cn.readthedocs.io/en/latest/other/callbacks/

回调函数Callbacks

回调函数是一个函数的合集,会在训练的阶段中所使用。你可以使用回调函数来查看训练模型的内在状态和统计。你可以传递一个列表的回调函数(作为 callbacks 关键字参数)到 Sequential 或 Model 类型的 .fit() 方法。在训练时,相应的回调函数的方法就会被在各自的阶段被调用。

Callback

keras.callbacks.Callback()

这是回调函数的抽象类,定义新的回调函数必须继承自该类

类属性

params:字典,训练参数集(如信息显示方法verbosity,batch大小,epoch数)

model:keras.models.Model对象,为正在训练的模型的引用

回调函数以字典logs为参数,该字典包含了一系列与当前batch或epoch相关的信息。

目前,模型的.fit()中有下列参数会被记录到logs中:

在每个epoch的结尾处(on_epoch_end),logs将包含训练的正确率和误差,acc和loss,如果指定了验证集,还会包含验证集正确率和误差val_acc)和val_loss,val_acc还额外需要在.compile中启用metrics=['accuracy']。

在每个batch的开始处(on_batch_begin):logs包含size,即当前batch的样本数

在每个batch的结尾处(on_batch_end):logs包含loss,若启用accuracy则还包含acc

ModelCheckpoint

keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)

该回调函数将在每个epoch后保存模型到filepath

filepath 可以包括命名格式选项,可以由 epoch 的值和 logs 的键(由 on_epoch_end 参数传递)来填充。

参数:

filepath: 字符串,保存模型的路径。

monitor: 被监测的数据。val_acc或这val_loss

verbose: 详细信息模式,0 或者 1 。0为不打印输出信息,1打印

save_best_only: 如果 save_best_only=True, 将只保存在验证集上性能最好的模型

mode: {auto, min, max} 的其中之一。 如果 save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于 val_acc,模式就会是 max,而对于 val_loss,模式就需要是 min,等等。 在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。

save_weights_only: 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)), 否则的话,整个模型会被保存 (model.save(filepath))。

period: 每个检查点之间的间隔(训练轮数)。

代码实现过程:

① 从keras.callbacks导入ModelCheckpoint类

from keras.callbacks import ModelCheckpoint

② 在训练阶段的model.compile之后加入下列代码实现每一次epoch(period=1)保存最好的参数

checkpoint = ModelCheckpoint(filepath,
monitor='val_loss', save_weights_only=True,verbose=1,save_best_only=True, period=1)

③ 在训练阶段的model.fit之前加载先前保存的参数

if os.path.exists(filepath):
 model.load_weights(filepath)
 # 若成功加载前面保存的参数,输出下列信息
 print("checkpoint_loaded")

④ 在model.fit添加callbacks=[checkpoint]实现回调

model.fit_generator(data_generator_wrap(lines[:num_train], batch_size, input_shape, anchors, num_classes),
 steps_per_epoch=max(1, num_train//batch_size),
 validation_data=data_generator_wrap(lines[num_train:], batch_size, input_shape, anchors, num_classes),
 validation_steps=max(1, num_val//batch_size),
 epochs=3,
 initial_epoch=0,
 callbacks=[checkpoint])

补充知识:keras之多输入多输出(多任务)模型

keras多输入多输出模型,以keras官网的demo为例,分析keras多输入多输出的适用。

主要输入(main_input): 新闻标题本身,即一系列词语。

辅助输入(aux_input): 接受额外的数据,例如新闻标题的发布时间等。

该模型将通过两个损失函数进行监督学习。

较早地在模型中使用主损失函数,是深度学习模型的一个良好正则方法。

完整过程图示如下:

keras 回调函数Callbacks 断点ModelCheckpoint教程

其中,红圈中的操作为将辅助数据与LSTM层的输出连接起来,输入到模型中。

代码实现:

import keras
from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model
 
# 定义网络模型 
# 标题输入:接收一个含有 100 个整数的序列,每个整数在 1 到 10000 之间
# 注意我们可以通过传递一个 `name` 参数来命名任何层
main_input = Input(shape=(100,), dtype='int32', name='main_input')
 
# Embedding 层将输入序列编码为一个稠密向量的序列,每个向量维度为 512
x = Embedding(output_dim=512, input_dim=10000, input_length=100)(main_input)
 
# LSTM 层把向量序列转换成单个向量,它包含整个序列的上下文信息
lstm_out = LSTM(32)(x)
 
# 在这里我们添加辅助损失,使得即使在模型主损失很高的情况下,LSTM层和Embedding层都能被平稳地训练
auxiliary_output = Dense(1, activation='sigmoid', name='aux_output')(lstm_out)
 
# 此时,我们将辅助输入数据与LSTM层的输出连接起来,输入到模型中
auxiliary_input = Input(shape=(5,), name='aux_input')
x = keras.layers.concatenate([lstm_out, auxiliary_output])
 
# 再添加剩余的层
# 堆叠多个全连接网络层
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
 
# 最后添加主要的逻辑回归层
main_output = Dense(1, activation='sigmoid', name='main_output')(x)
 
# 定义这个具有两个输入和输出的模型
model = Model(inputs=[main_input, auxiliary_input], outputs=[main_output, auxiliary_output])
 
# 编译模型时候分配损失函数权重:编译模型的时候,给 辅助损失 分配一个0.2的权重
model.compile(optimizer='rmsprop', loss='binary_crossentropy', loss_weights=[1., 0.2])
 
# 训练模型:我们可以通过传递输入数组和目标数组的列表来训练模型
model.fit([headline_data, additional_data], [labels, labels], epochs=50, batch_size=32)
 
# 另外一种利用字典的编译、训练方式
# 由于输入和输出均被命名了(在定义时传递了一个 name 参数),我们也可以通过以下方式编译模型
model.compile(optimizer='rmsprop',
    loss={'main_output': 'binary_crossentropy', 'aux_output': 'binary_crossentropy'},
    loss_weights={'main_output': 1., 'aux_output': 0.2})
# 然后使用以下方式训练:
model.fit({'main_input': headline_data, 'aux_input': additional_data},
   {'main_output': labels, 'aux_output': labels},
   epochs=50, batch_size=32)

相关参考:https://keras.io/zh/getting-started/functional-api-guide/

以上这篇keras 回调函数Callbacks 断点ModelCheckpoint教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python对html代码进行escape编码的方法
May 04 Python
在Mac OS上搭建Python的开发环境
Dec 24 Python
详解pyqt5 动画在QThread线程中无法运行问题
May 05 Python
python numpy 部分排序 寻找最大的前几个数的方法
Jun 27 Python
浅谈django的render函数的参数问题
Oct 16 Python
python 阶乘累加和的实例
Feb 01 Python
我就是这样学习Python中的列表
Jun 02 Python
浅析Python+OpenCV使用摄像头追踪人脸面部血液变化实现脉搏评估
Oct 17 Python
Python 一行代码能实现丧心病狂的功能
Jan 18 Python
彻底搞懂 python 中文乱码问题(深入分析)
Feb 28 Python
python 装饰器的实际作用有哪些
Sep 07 Python
Python虚拟环境的创建和使用详解
Sep 07 Python
浅谈keras.callbacks设置模型保存策略
Jun 18 #Python
用python实现名片管理系统
Jun 18 #Python
Python 为什么推荐蛇形命名法原因浅析
Jun 18 #Python
python退出循环的方法
Jun 18 #Python
keras实现多GPU或指定GPU的使用介绍
Jun 17 #Python
Python字符串格式化常用手段及注意事项
Jun 17 #Python
python代码区分大小写吗
Jun 17 #Python
You might like
php 采集书并合成txt格式的实现代码
2009/03/01 PHP
PHP临时文件的安全性分析
2014/07/04 PHP
微信支付开发交易通知实例
2016/07/12 PHP
使用PHP连接多种数据库的实现代码(mysql,access,sqlserver,Oracle)
2016/12/21 PHP
PHP实现的数独求解问题示例
2017/04/18 PHP
javascript fullscreen全屏实现代码
2009/04/09 Javascript
node.js中的fs.stat方法使用说明
2014/12/16 Javascript
JavaScript实现页面5秒后自动跳转的方法
2015/04/16 Javascript
JQuery简单实现锚点链接的平滑滚动
2015/05/03 Javascript
javascript中JSON.parse()与eval()解析json的区别
2016/05/19 Javascript
详解vue-cli + webpack 多页面实例应用
2017/04/25 Javascript
Javascript(es2016) import和require用法和区别详解
2017/08/11 Javascript
JS实现点击下拉菜单把选择的内容同步到input输入框内的实例
2018/01/23 Javascript
JavaScript实现写入文件到本地的方法【基于FileSaver.js插件】
2018/03/15 Javascript
使用Angular CLI生成路由的方法
2018/03/24 Javascript
如何手动实现es5中的bind方法详解
2018/12/07 Javascript
vue 点击展开显示更多(点击收起部分隐藏)
2019/04/09 Javascript
jQuery 选择器用法实例分析【prev + next】
2020/05/22 jQuery
在vue-cli3中使用axios获取本地json操作
2020/07/30 Javascript
Python SQLite3数据库操作类分享
2014/06/10 Python
自动化Nginx服务器的反向代理的配置方法
2015/06/28 Python
python中print的不换行即时输出的快速解决方法
2016/07/20 Python
使用python 爬虫抓站的一些技巧总结
2018/01/10 Python
Python 查看文件的读写权限方法
2018/01/23 Python
python素数筛选法浅析
2018/03/19 Python
Pycharm之快速定位到某行快捷键的方法
2019/01/20 Python
Python sep参数使用方法详解
2020/02/12 Python
英国领先的在线高尔夫商店:Gamola Golf
2019/11/16 全球购物
社区工作者演讲稿
2014/05/23 职场文书
银行授权委托书样本
2014/10/13 职场文书
青年教师个人总结
2015/02/11 职场文书
使用canvas实现雪花飘动效果的示例代码
2021/03/30 HTML / CSS
golang json数组拼接的实例
2021/04/28 Golang
Java Spring 控制反转(IOC)容器详解
2021/10/05 Java/Android
如何利用python实现列表嵌套字典取值
2022/06/10 Python
springboot创建的web项目整合Quartz框架的项目实践
2022/06/21 Java/Android