Keras 实现加载预训练模型并冻结网络的层


Posted in Python onJune 15, 2020

在解决一个任务时,我会选择加载预训练模型并逐步fine-tune。比如,分类任务中,优异的深度学习网络有很多。

ResNet, VGG, Xception等等... 并且这些模型参数已经在imagenet数据集中训练的很好了,可以直接拿过来用。

根据自己的任务,训练一下最后的分类层即可得到比较好的结果。此时,就需要“冻结”预训练模型的所有层,即这些层的权重永不会更新。

以Xception为例:

加载预训练模型:

from tensorflow.python.keras.applications import Xception
model = Sequential()
model.add(Xception(include_top=False, pooling='avg', weights='imagenet'))
model.add(Dense(NUM_CLASS, activation='softmax'))

include_top = False : 不包含顶层的3个全链接网络

weights : 加载预训练权重

随后,根据自己的分类任务加一层网络即可。

网络具体参数:

model.summary

得到两个网络层,第一层是xception层,第二层为分类层。

由于未冻结任何层,trainable params为:20, 811, 050

Keras 实现加载预训练模型并冻结网络的层

冻结网络层:

由于第一层为xception,不想更新xception层的参数,可以加以下代码:

model.layers[0].trainable = False

Keras 实现加载预训练模型并冻结网络的层

冻结预训练模型中的层

如果想冻结xception中的部分层,可以如下操作:

from tensorflow.python.keras.applications import Xception
model = Sequential()
model.add(Xception(include_top=False, pooling='avg', weights='imagenet'))
model.add(Dense(NUM_CLASS, activation='softmax'))
for i, layer in enumerate(model.layers[0].layers):
 if i > 115:
 layer.trainable = True
 else:
 layer.trainable = False
 print(i, layer.name, layer.trainable)

Keras 实现加载预训练模型并冻结网络的层

Keras 实现加载预训练模型并冻结网络的层

加载所有预训练模型的层

若想把xeption的所有层应用在训练自己的数据,并改变分类数。可以如下操作:

model = Sequential()
model.add(Xception(include_top=True, weights=None, classes=NUM_CLASS))

* 如果想指定classes,有两个条件:include_top:True, weights:None。否则无法指定classes

补充知识:如何利用预训练模型进行模型微调(如冻结某些层,不同层设置不同学习率等)

由于预训练模型权重和我们要训练的数据集存在一定的差异,且需要训练的数据集有大有小,所以进行模型微调、设置不同学习率就变得比较重要,下面主要分四种情况进行讨论,错误之处或者不足之处还请大佬们指正。

(1)待训练数据集较小,与预训练模型数据集相似度较高时。例如待训练数据集中数据存在于预训练模型中时,不需要重新训练模型,只需要修改最后一层输出层即可。

(2)待训练数据集较小,与预训练模型数据集相似度较小时。可以冻结模型的前k层,重新模型的后n-k层。冻结模型的前k层,用于弥补数据集较小的问题。

(3)待训练数据集较大,与预训练模型数据集相似度较大时。采用预训练模型会非常有效,保持模型结构不变和初始权重不变,对模型重新训练

(4)待训练数据集较大,与预训练模型数据集相似度较小时。采用预训练模型不会有太大的效果,可以使用预训练模型或者不使用预训练模型,然后进行重新训练。

以上这篇Keras 实现加载预训练模型并冻结网络的层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python如何让类支持比较运算
Mar 20 Python
python实现两个文件合并功能
Apr 01 Python
python数字图像处理实现直方图与均衡化
May 04 Python
Python实现计算字符串中出现次数最多的字符示例
Jan 21 Python
Scrapy框架爬取Boss直聘网Python职位信息的源码
Feb 22 Python
Python redis操作实例分析【连接、管道、发布和订阅等】
May 16 Python
使用python打印十行杨辉三角过程详解
Jul 10 Python
python获取array中指定元素的示例
Nov 26 Python
python实现Pyecharts实现动态地图(Map、Geo)
Mar 25 Python
Python 如何创建一个线程池
Jul 28 Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 Python
python中的测试框架
Nov 13 Python
Python StringIO及BytesIO包使用方法解析
Jun 15 #Python
Python smtp邮件发送模块用法教程
Jun 15 #Python
pandas数据处理之绘图的实现
Jun 15 #Python
keras中的loss、optimizer、metrics用法
Jun 15 #Python
使用keras实现Precise, Recall, F1-socre方式
Jun 15 #Python
基于python和flask实现http接口过程解析
Jun 15 #Python
基于nexus3配置Python仓库过程详解
Jun 15 #Python
You might like
虫族 Zerg 热键控制
2020/03/14 星际争霸
自己在做项目过程中学到的PHP知识收集
2012/08/20 PHP
wordpress自定义url参数实现路由功能的代码示例
2013/11/28 PHP
PHP的CURL方法curl_setopt()函数案例介绍(抓取网页,POST数据)
2016/12/14 PHP
Yii2框架实现登录、退出及自动登录功能的方法详解
2017/10/24 PHP
php中字符串和整数比较的操作方法
2019/06/06 PHP
use jscript List Installed Software
2007/06/11 Javascript
jQuery 隔行换色 支持键盘上下键,按Enter选定值
2009/08/02 Javascript
js关闭当前页面(窗口)的几种方式总结
2013/03/05 Javascript
js复制网页内容并兼容各主流浏览器的代码
2013/12/17 Javascript
JavaScript数值数组排序示例分享
2014/05/27 Javascript
javascript随机之洗牌算法深入分析
2014/06/07 Javascript
js实现获取鼠标当前的位置
2016/12/14 Javascript
微信小程序 radio单选框组件详解及实例代码
2017/01/10 Javascript
JS批量替换内容中关键词为超链接
2017/02/20 Javascript
jquery中each循环的简单回滚操作
2017/05/05 jQuery
Vue.js上下滚动加载组件的实例代码
2017/07/17 Javascript
Vue组件之极简的地址选择器的实现
2018/05/31 Javascript
jQuery实现小火箭返回顶部特效
2020/02/03 jQuery
[04:00]黄浦江畔,再会英雄——完美世界DOTA2 TI9应援视频
2019/07/31 DOTA
python3.7 sys模块的具体使用
2019/07/22 Python
如何配置关联Python 解释器 Anaconda的教程(图解)
2020/04/30 Python
Tensorflow全局设置可见GPU编号操作
2020/06/30 Python
Python unittest装饰器实现原理及代码
2020/09/08 Python
美国领先的医疗警报服务:Philips Lifeline
2018/03/12 全球购物
地理科学专业毕业生求职信
2013/10/15 职场文书
出纳的岗位职责
2013/11/09 职场文书
高校教师思想汇报
2014/01/11 职场文书
电脑饰品店的创业计划书
2014/01/21 职场文书
群众路线教育实践活动方案
2014/02/02 职场文书
安全承诺书范文
2014/03/26 职场文书
药剂专业自荐信范文
2014/04/16 职场文书
80后婚前协议书范本
2014/10/24 职场文书
2015年社区教育工作总结
2015/05/13 职场文书
Django展示可视化图表的多种方式
2021/04/08 Python
java设计模式--三种工厂模式详解
2021/07/21 Java/Android