Kears 使用:通过回调函数保存最佳准确率下的模型操作


Posted in Python onJune 17, 2020

1:首先,我给我的MixTest文件夹里面分好了类的图片进行重命名(因为分类的时候没有注意导致命名有点不好)

def load_data(path):
 Rename the picture [a tool]
 for eachone in os.listdir(path):
  newname = eachone[7:]
  os.rename(path+"\\"+eachone,path+"\\"+newname)

但是需要注意的是:我们按照类重命名了以后,系统其实会按照图片来排序。这个时候你会看到同一个类的被排序在了一块。这个时候你不要慌张,其实这个顺序是完全不用担心的。我们只是需要得到网络对某一个图片的输出是怎么样的判断标签。这个顺序对网络计算其权重完全是没有任何的影响的

2:我在Keras中使用InceptionV3这个模型进行训练,训练模型的过程啥的我在这里就不详细说了(毕竟这个东西有点像随记那样的东西)

我们在Keras的模型里面是可以通过

H.history["val_acc"]
H.history["val_loss"]

来的得到历史交叉准确率这样的指标

3:

对于每个epoch,我们都会计算一次val_acc和val_loss,我很希望保留下我最高的val_acc的模型,那该怎么办呢?

这个时候我就会使用keras的callback函数

H = model.fit_generator(train_datagen.flow(X_train, Y_train, batch_size=batchsize),
  validation_data=(X_test, Y_test), steps_per_epoch=(X_train.shape[0]) // batchsize,
  epochs=epoch, verbose=1, callbacks=[tb(log_dir='E:\John\log'),
           save_function])

上面的参数先查查文档把。这里我就说说我的callbacks

callbacks=[tb(log_dir = 'E\John\log')]

这个是使用tensorboard来可视化训练过程的,后面是tensorboard的log输出文件夹的路径,在网络训练的时候,相对应的训练的状态就会保存在这个文件夹下

打开终端,输入

tensorboard --log_dir <your name of the log dir> --port <the port for tensorboard>

然后输入终端指示的网址在浏览器中打开,就可以在tensorboard中看到你训练的状态了

save_function:

这是一个类的实例化:

class Save(keras.callbacks.Callback):
 def __init__(self):
  self.max_acc = 0.0
 
 def on_epoch_begin(self, epoch, logs=None):
  pass
 
 def on_epoch_end(self, epoch, logs=None):
  self.val_acc = logs["val_acc"]
  if epoch != 0:
   if self.val_acc > self.max_acc and self.val_acc > 0.8:
    model.save("kears_model_"+str(epoch)+ "_acc="+str(self.val_acc)+".h5")
    self.max_acc = self.val_acc
 
save_function = Save()

这里继承了kears.callbacks.Callback

看看on_epoch_end:

在这个epoch结束的时候,我会得到它的val_acc

当这个val_acc为历史最大值的时候,我就保存这个模型

在训练结束以后,你就挑出acc最大的就好啦(当然,你可以命名为一样的,最后的到的模型就不用挑了,直接就是acc最大的模型了)

补充知识:Keras回调函数Callbacks使用详解及训练过程可视化

介绍

内容参考了keras中文文档

回调函数Callbacks

回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息。通过传递回调函数列表到模型的.fit()中,即可在给定的训练阶段调用该函数集中的函数。

【Tips】虽然我们称之为回调“函数”,但事实上Keras的回调函数是一个类,回调函数只是习惯性称呼

keras.callbacks.Callback()

这是回调函数的抽象类,定义新的回调函数必须继承自该类

类属性:

params:字典,训练参数集(如信息显示方法verbosity,batch大小,epoch数)

model:keras.models.Model对象,为正在训练的模型的引用

回调函数以字典logs为参数,该字典包含了一系列与当前batch或epoch相关的信息。

目前,模型的.fit()中有下列参数会被记录到logs中:

在每个epoch的结尾处(on_epoch_end),logs将包含训练的正确率和误差,acc和loss,如果指定了验证集,还会包含验证集正确率和误差val_acc)和val_loss,val_acc还额外需要在.compile中启用metrics=[‘accuracy']。

在每个batch的开始处(on_batch_begin):logs包含size,即当前batch的样本数

在每个batch的结尾处(on_batch_end):logs包含loss,若启用accuracy则还包含acc

from keras.callbacks import Callback

功能

History(训练可视化

keras.callbacks.History()

该回调函数在Keras模型上会被自动调用,History对象即为fit方法的返回值,可以使用history中的存储的acc和loss数据对训练过程进行可视化画图,代码样例如下:

history=model.fit(X_train, Y_train, validation_data=(X_test,Y_test),
 batch_size=16, epochs=20)
##或者
#history=model.fit(X_train,y_train,epochs=40,callbacks=callbacks, batch_size=32,validation_data=(X_test,y_test)) 
fig1, ax_acc = plt.subplots()
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Model - Accuracy')
plt.legend(['Training', 'Validation'], loc='lower right')
plt.show()

fig2, ax_loss = plt.subplots()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Model- Loss')
plt.legend(['Training', 'Validation'], loc='upper right')
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.show()

EarlyStopping

keras.callbacks.EarlyStopping(monitor='val_loss', patience=0, verbose=0, mode='auto')

当监测值不再改善时,该回调函数将中止训练

参数

monitor:需要监视的量

patience:当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。

verbose:信息展示模式

verbose = 0 为不在标准输出流输出日志信息

verbose = 1 为输出进度条记录

verbose = 2 为每个epoch输出一行记录

默认为 1

mode:‘auto',‘min',‘max'之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。

ModelCheckpoint

该回调函数将在每个epoch后保存模型到filepath

filepath可以是格式化的字符串,里面的占位符将会被epoch值和传入on_epoch_end的logs关键字所填入

例如,filepath若为weights.{epoch:02d-{val_loss:.2f}}.hdf5,则会生成对应epoch和验证集loss的多个文件。

参数

filename:字符串,保存模型的路径

monitor:需要监视的值

verbose:信息展示模式,0或1

save_best_only:当设置为True时,将只保存在验证集上性能最好的模型

mode:‘auto',‘min',‘max'之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。

save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)

period:CheckPoint之间的间隔的epoch数

Callbacks中可以同时使用多个以上两个功能,举例如下

callbacks = [EarlyStopping(monitor='val_loss', patience=8),
    ModelCheckpoint(filepath='best_model.h5', monitor='val_loss', save_best_only=True)]
history=model.fit(X_train, y_train,epochs=40,callbacks=callbacks, batch_size=32,validation_data=(X_test,y_test))

在样例中,EarlyStopping设置衡量标注为val_loss,如果其连续4次没有下降就提前停止 ,ModelCheckpoint设置衡量标准为val_loss,设置只保存最佳模型,保存路径为best——model.h5

ReduceLROnPlateau

keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0)

当评价指标不在提升时,减少学习率

当学习停滞时,减少2倍或10倍的学习率常常能获得较好的效果。该回调函数检测指标的情况,如果在patience个epoch中看不到模型性能提升,则减少学习率

参数

monitor:被监测的量 factor:每次减少学习率的因子,学习率将以lr = lr*factor的形式被减少

patience:当patience个epoch过去而模型性能不提升时,学习率减少的动作会被触发

mode:‘auto',‘min',‘max'之一,在min模式下,如果检测值触发学习率减少。在max模式下,当检测值不再上升则触发学习率减少。

epsilon:阈值,用来确定是否进入检测值的“平原区”

cooldown:学习率减少后,会经过cooldown个epoch才重新进行正常操作 min_lr:学习率的下限

使用样例如下:

callbacks_test = [
 keras.callbacks.ReduceLROnPlateau(
 #以val_loss作为衡量标准
 monitor='val_loss',
 # 学习率乘以factor
 factor=0.1,
 # It will get triggered after the validation loss has stopped improving
 # 当被检测的衡量标准经过几次没有改善后就减小学习率
 patience=10,
 )
 ]
 model.fit(x, y,epochs=20,batch_size=16,
  callbacks=callbacks_test,
 validation_data=(x_val, y_val))

CSVLogger

keras.callbacks.CSVLogger(filename, separator=',', append=False)

将epoch的训练结果保存在csv文件中,支持所有可被转换为string的值,包括1D的可迭代数值如np.ndarray.

参数

fiename:保存的csv文件名,如run/log.csv

separator:字符串,csv分隔符

append:默认为False,为True时csv文件如果存在则继续写入,为False时总是覆盖csv文件

以上这篇Kears 使用:通过回调函数保存最佳准确率下的模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现爬取逐浪小说的方法
Jul 07 Python
Python实现数通设备端口使用情况监控实例
Jul 15 Python
利用python批量修改word文件名的方法示例
Oct 17 Python
python编写微信远程控制电脑的程序
Jan 05 Python
python3在同一行内输入n个数并用列表保存的例子
Jul 20 Python
Python安装selenium包详细过程
Jul 23 Python
Python中typing模块与类型注解的使用方法
Aug 05 Python
Python导入模块包原理及相关注意事项
Mar 25 Python
基于python实现计算两组数据P值
Jul 10 Python
python利用xlsxwriter模块 操作 Excel
Oct 14 Python
python time()的实例用法
Nov 03 Python
python通用数据库操作工具 pydbclib的使用简介
Dec 21 Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
浅谈Python协程
Jun 17 #Python
You might like
PHP UTF8编码内的繁简转换类
2009/07/20 PHP
php阻止页面后退的方法分享
2014/02/17 PHP
CentOS 上搭建 PHP7 开发测试环境
2017/02/26 PHP
PHP标准库(PHP SPL)详解
2019/03/16 PHP
优化网页之快速的呈现我们的网页
2007/06/29 Javascript
js控制CSS样式属性语法对照表
2012/12/11 Javascript
window.open关于浏览器拦截问题分析及解决方法
2013/02/05 Javascript
模拟用户点击弹出新页面不会被浏览器拦截
2014/04/08 Javascript
Javascript 拖拽雏形(逐行分析代码,让你轻松了拖拽的原理)
2015/01/23 Javascript
常用原生JS兼容性写法汇总
2016/04/27 Javascript
基于BootStrap Metronic开发框架经验小结【三】下拉列表Select2插件的使用
2016/05/12 Javascript
浅析jquery数组删除指定元素的方法:grep()
2016/05/19 Javascript
BootStrap智能表单实战系列(六)表单编辑页面的数据绑定
2016/06/13 Javascript
jquery.qtip提示信息插件用法简单实例
2016/06/17 Javascript
简单实现IONIC购物车功能
2017/01/10 Javascript
微信小程序tabbar不显示解决办法
2017/06/08 Javascript
微信小程序wx.previewImage预览图片实例详解
2017/12/07 Javascript
vue 关闭浏览器窗口的时候,清空localStorage的数据示例
2019/11/06 Javascript
在VUE中使用lodash的debounce和throttle操作
2020/11/09 Javascript
详解ES6 中的Object.assign()的用法实例代码
2021/01/11 Javascript
python实现html转ubb代码(html2ubb)
2014/07/03 Python
python 时间戳与格式化时间的转化实现代码
2016/03/23 Python
关于numpy中np.nonzero()函数用法的详解
2017/02/07 Python
Python中多个数组行合并及列合并的方法总结
2018/04/12 Python
Python 一键获取百度网盘提取码的方法
2019/08/01 Python
Python操作SQLite/MySQL/LMDB数据库的方法
2019/11/07 Python
python 截取XML中bndbox的坐标中的图像,另存为jpg的实例
2020/03/10 Python
Python面向对象程序设计之私有变量,私有方法原理与用法分析
2020/03/23 Python
python 一维二维插值实例
2020/04/22 Python
大学活动邀请函
2014/01/28 职场文书
年度献血先进个人事迹材料
2014/02/14 职场文书
2014年建筑工作总结
2014/11/26 职场文书
财务经理岗位职责范本
2015/04/08 职场文书
微信小程序用户授权最佳实践指南
2021/05/08 Javascript
Python爬取英雄联盟MSI直播间弹幕并生成词云图
2021/06/01 Python
CSS实现两列布局的N种方法
2021/08/02 HTML / CSS