Keras 实现加载预训练模型并冻结网络的层


Posted in Python onJune 15, 2020

在解决一个任务时,我会选择加载预训练模型并逐步fine-tune。比如,分类任务中,优异的深度学习网络有很多。

ResNet, VGG, Xception等等... 并且这些模型参数已经在imagenet数据集中训练的很好了,可以直接拿过来用。

根据自己的任务,训练一下最后的分类层即可得到比较好的结果。此时,就需要“冻结”预训练模型的所有层,即这些层的权重永不会更新。

以Xception为例:

加载预训练模型:

from tensorflow.python.keras.applications import Xception
model = Sequential()
model.add(Xception(include_top=False, pooling='avg', weights='imagenet'))
model.add(Dense(NUM_CLASS, activation='softmax'))

include_top = False : 不包含顶层的3个全链接网络

weights : 加载预训练权重

随后,根据自己的分类任务加一层网络即可。

网络具体参数:

model.summary

得到两个网络层,第一层是xception层,第二层为分类层。

由于未冻结任何层,trainable params为:20, 811, 050

Keras 实现加载预训练模型并冻结网络的层

冻结网络层:

由于第一层为xception,不想更新xception层的参数,可以加以下代码:

model.layers[0].trainable = False

Keras 实现加载预训练模型并冻结网络的层

冻结预训练模型中的层

如果想冻结xception中的部分层,可以如下操作:

from tensorflow.python.keras.applications import Xception
model = Sequential()
model.add(Xception(include_top=False, pooling='avg', weights='imagenet'))
model.add(Dense(NUM_CLASS, activation='softmax'))
for i, layer in enumerate(model.layers[0].layers):
 if i > 115:
 layer.trainable = True
 else:
 layer.trainable = False
 print(i, layer.name, layer.trainable)

Keras 实现加载预训练模型并冻结网络的层

Keras 实现加载预训练模型并冻结网络的层

加载所有预训练模型的层

若想把xeption的所有层应用在训练自己的数据,并改变分类数。可以如下操作:

model = Sequential()
model.add(Xception(include_top=True, weights=None, classes=NUM_CLASS))

* 如果想指定classes,有两个条件:include_top:True, weights:None。否则无法指定classes

补充知识:如何利用预训练模型进行模型微调(如冻结某些层,不同层设置不同学习率等)

由于预训练模型权重和我们要训练的数据集存在一定的差异,且需要训练的数据集有大有小,所以进行模型微调、设置不同学习率就变得比较重要,下面主要分四种情况进行讨论,错误之处或者不足之处还请大佬们指正。

(1)待训练数据集较小,与预训练模型数据集相似度较高时。例如待训练数据集中数据存在于预训练模型中时,不需要重新训练模型,只需要修改最后一层输出层即可。

(2)待训练数据集较小,与预训练模型数据集相似度较小时。可以冻结模型的前k层,重新模型的后n-k层。冻结模型的前k层,用于弥补数据集较小的问题。

(3)待训练数据集较大,与预训练模型数据集相似度较大时。采用预训练模型会非常有效,保持模型结构不变和初始权重不变,对模型重新训练

(4)待训练数据集较大,与预训练模型数据集相似度较小时。采用预训练模型不会有太大的效果,可以使用预训练模型或者不使用预训练模型,然后进行重新训练。

以上这篇Keras 实现加载预训练模型并冻结网络的层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python之PyMongo使用总结
May 26 Python
python爬取足球直播吧五大联赛积分榜
Jun 13 Python
Python提取频域特征知识点浅析
Mar 04 Python
Python 3.8中实现functools.cached_property功能
May 29 Python
简单了解python高阶函数map/reduce
Jun 28 Python
Django中celery执行任务结果的保存方法
Jul 12 Python
运用PyTorch动手搭建一个共享单车预测器
Aug 06 Python
Pandas时间序列重采样(resample)方法中closed、label的作用详解
Dec 10 Python
Jupyter notebook 远程配置及SSL加密教程
Apr 14 Python
学会迭代器设计模式,帮你大幅提升python性能
Jan 03 Python
20行代码教你用python给证件照换底色的方法示例
Feb 05 Python
Python爬取酷狗MP3音频的步骤
Feb 26 Python
Python StringIO及BytesIO包使用方法解析
Jun 15 #Python
Python smtp邮件发送模块用法教程
Jun 15 #Python
pandas数据处理之绘图的实现
Jun 15 #Python
keras中的loss、optimizer、metrics用法
Jun 15 #Python
使用keras实现Precise, Recall, F1-socre方式
Jun 15 #Python
基于python和flask实现http接口过程解析
Jun 15 #Python
基于nexus3配置Python仓库过程详解
Jun 15 #Python
You might like
利用PHP实现与ASP Banner组件相似的类
2006/10/09 PHP
PHP之浮点数计算比较以及取整数不准确的解决办法
2015/07/29 PHP
轻松掌握php设计模式之访问者模式
2016/09/23 PHP
php smtp实现发送邮件功能
2017/06/22 PHP
PHP实现基于PDO扩展连接PostgreSQL对象关系数据库示例
2018/03/31 PHP
thinkphp框架类库扩展操作示例
2019/11/26 PHP
ImageFlow可鼠标控制图片滚动
2008/01/30 Javascript
基于jQuery的固定表格头部的代码(IE6,7,8测试通过)
2010/05/18 Javascript
javascript中字符串拼接需注意的问题
2010/07/13 Javascript
yepnope.js 异步加载资源文件
2011/09/08 Javascript
基于jquery的DIV随滚动条滚动而滚动的代码
2012/07/20 Javascript
js截取中英文字符串、标点符号无乱码示例解读
2014/04/17 Javascript
js闭包实现按秒计数
2015/04/23 Javascript
详解js产生对象的3种基本方式(工厂模式,构造函数模式,原型模式)
2017/01/09 Javascript
jQuery倒计时代码(超简单)
2017/02/27 Javascript
详解nodejs微信公众号开发——5.素材管理接口
2017/04/11 NodeJs
React Native仿美团下拉菜单的实例代码
2017/08/08 Javascript
微信小程序实现人脸检测功能
2018/05/25 Javascript
详解微信小程序调起键盘性能优化
2018/07/24 Javascript
NodeJs之word文件生成与解析的实现代码
2019/04/01 NodeJs
JavaScript变速动画函数封装添加任意多个属性
2019/04/03 Javascript
利用Python如何生成便签图片详解
2018/07/09 Python
Python的bit_length函数来二进制的位数方法
2019/08/27 Python
使用Pandas的Series方法绘制图像教程
2019/12/04 Python
python实现opencv+scoket网络实时图传
2020/03/20 Python
Python matplotlib实时画图案例
2020/04/23 Python
pytorch cuda上tensor的定义 以及减少cpu的操作详解
2020/06/23 Python
如何在python中实现线性回归
2020/08/10 Python
Manjaro、pip、conda更换国内源的方法
2020/11/17 Python
全球才华横溢工匠的家居装饰、珠宝和礼物:NOVICA
2021/01/22 全球购物
公司委托书怎么写
2014/08/02 职场文书
2014年中学生检讨书大全
2014/10/09 职场文书
文明单位申报材料
2014/12/23 职场文书
导游词之云南省玉龙雪山
2019/12/19 职场文书
Ajax 的初步实现(使用vscode+node.js+express框架)
2021/06/18 Javascript
Python可视化神器pyecharts绘制水球图
2022/07/07 Python