浅谈keras 模型用于预测时的注意事项


Posted in Python onJune 27, 2020

为什么训练误差比测试误差高很多?

一个Keras的模型有两个模式:训练模式测试模式一些正则机制,如Dropout,L1/L2正则项在测试模式下将不被启用。

另外,训练误差是训练数据每个batch的误差的平均。在训练过程中,每个epoch起始时的batch的误差要大一些,而后面的batch的误差要小一些。另一方面,每个epoch结束时计算的测试误差是由模型在epoch结束时的状态决定的,这时候的网络将产生较小的误差。

【Tips】可以通过定义回调函数将每个epoch的训练误差和测试误差并作图,如果训练误差曲线和测试误差曲线之间有很大的空隙,说明你的模型可能有过拟合的问题。当然,这个问题与Keras无关。

在keras中文文档中指出了这一误区,笔者认为产生这一问题的原因在于网络实现的机制。即dropout层有前向实现和反向实现两种方式,这就决定了概率p是在训练时候设置还是测试的时候进行设置

利用预训练的权值进行Fine tune时的注意事项:

不能把自己添加的层进行将随机初始化后直接连接到前面预训练后的网络层

in order to perform fine-tuning, all layers should start with properly trained weights: for instance you should not slap a randomly initialized fully-connected network on top of a pre-trained convolutional base. This is because the large gradient updates triggered by the randomly initialized weights would wreck the learned weights in the convolutional base. In our case this is why we first train the top-level classifier, and only then start fine-tuning convolutional weights alongside it.

we choose to only fine-tune the last convolutional block rather than the entire network in order to prevent overfitting, since the entire network would have a very large entropic capacity and thus a strong tendency to overfit. The features learned by low-level convolutional blocks are more general, less abstract than those found higher-up, so it is sensible to keep the first few blocks fixed (more general features) and only fine-tune the last one (more specialized features).

fine-tuning should be done with a very slow learning rate, and typically with the SGD optimizer rather than an adaptative learning rate optimizer such as RMSProp. This is to make sure that the magnitude of the updates stays very small, so as not to wreck the previously learned features.

补充知识:keras框架中用keras.models.Model做的时候预测数据不是标签的问题

我们发现,在用Sequential去搭建网络的时候,其中有predict和predict_classes两个预测函数,前一个是返回的精度,后面的是返回的具体标签。但是,在使用keras.models.Model去做的时候,就会发现,它只有一个predict函数,没有返回标签的predict_classes函数,所以,针对这个问题,我们将其改写。改写如下:

def my_predict_classes(predict_data):
  if predict_data.shape[-1] > 1:
    return predict_data.argmax(axis=-1)
  else:
    return (predict_data > 0.5).astype('int32')
 
# 这里省略网络搭建部分。。。。
 
model = Model(data_input, label_output)
model.compile(loss='categorical_crossentropy',
       optimizer=keras.optimizers.Nadam(lr=0.002),
       metrics=['accuracy'])
model.summary()
 
y_predict = model.predict(X_test)
y_pre = my_predict_classes(y_predict)

这样,y_pre就是具体的标签了。

以上这篇浅谈keras 模型用于预测时的注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python处理PDF及生成多层PDF实例代码
Apr 24 Python
python3实现UDP协议的服务器和客户端
Jun 14 Python
Django中利用filter与simple_tag为前端自定义函数的实现方法
Jun 15 Python
python实现Dijkstra静态寻路算法
Jan 17 Python
python实现字符串加密 生成唯一固定长度字符串
Mar 22 Python
Python正则表达式匹配日期与时间的方法
Jul 07 Python
Python Web框架之Django框架文件上传功能详解
Aug 16 Python
python系统指定文件的查找只输出目录下所有文件及文件夹
Jan 19 Python
python实现控制台输出彩色字体
Apr 05 Python
jupyter notebook oepncv 显示一张图像的实现
Apr 24 Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
Jun 10 Python
Python WebSocket长连接心跳与短连接的示例
Nov 24 Python
python suds访问webservice服务实现
Jun 26 #Python
解析Python 偏函数用法全方位实现
Jun 26 #Python
Python如何优雅删除字符列表空字符及None元素
Jun 25 #Python
使用pytorch实现论文中的unet网络
Jun 24 #Python
python连接mysql有哪些方法
Jun 24 #Python
pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)
Jun 24 #Python
Python Tornado核心及相关原理详解
Jun 24 #Python
You might like
PHP4 与 MySQL 数据库操作函数详解
2006/10/09 PHP
php 服务器调试 Zend Debugger 的安装教程
2009/09/25 PHP
Zend Framework教程之Zend_Db_Table_Row用法实例分析
2016/03/21 PHP
php+MySql实现登录系统与输出浏览者信息功能
2016/07/01 PHP
PHP实现数组array转换成xml的方法
2016/07/19 PHP
Zend Framework入门教程之Zend_Db数据库操作详解
2016/12/08 PHP
利用php-cli和任务计划实现订单同步功能的方法
2017/05/03 PHP
跨浏览器的设置innerHTML方法
2006/09/18 Javascript
JS访问SWF的函数用法实例
2015/07/01 Javascript
javascript拖拽效果延伸学习
2016/04/04 Javascript
使用原生的javascript来实现轮播图
2017/02/24 Javascript
基于Node.js的WebSocket通信实现
2017/03/11 Javascript
bootstrap中的导航条实例代码详解
2019/05/20 Javascript
Nuxt.js nuxt-link与router-link的区别说明
2020/11/06 Javascript
Python数据结构之Array用法实例
2014/10/09 Python
编写同时兼容Python2.x与Python3.x版本的代码的几个示例
2015/03/30 Python
详解Django之auth模块(用户认证)
2018/04/17 Python
python将txt文档每行内容循环插入数据库的方法
2018/12/28 Python
详解利用OpenCV提取图像中的矩形区域(PPT屏幕等)
2019/07/01 Python
python logging 日志的级别调整方式
2020/02/21 Python
python ETL工具 pyetl
2020/06/07 Python
利用css3如何设置没有上下边的列表间隔线
2017/07/03 HTML / CSS
详解HTML5中垂直上下居中的解决方案
2017/12/20 HTML / CSS
HTML5 对各个标签的定义与规定:body的介绍
2012/06/21 HTML / CSS
zooplus波兰:在线宠物店
2019/07/21 全球购物
美国在线艺术商店:HandmadePiece
2020/11/06 全球购物
自我评价优秀范文分享
2013/11/30 职场文书
特色蛋糕店创业计划书
2014/01/28 职场文书
2014三八妇女节活动总结范文四篇
2014/03/09 职场文书
化学工程专业求职信
2014/08/10 职场文书
四风查摆问题及整改措施
2014/10/10 职场文书
学院党的群众路线教育实践活动第一阶段情况汇报
2014/10/25 职场文书
医生见习报告范文
2014/11/03 职场文书
2015年后勤工作总结范文
2015/04/08 职场文书
各类场合主持词开场白范文集锦
2019/08/16 职场文书
Python爬虫之爬取二手房信息
2021/04/27 Python