Keras: model实现固定部分layer,训练部分layer操作


Posted in Python onJune 28, 2020

需求:Resnet50做调优训练,将最后分类数目由1000改为500。

问题:网上下载了resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5,更改了Resnet50后,由于所有层均参加训练,导致训练速度慢。实际上只需要训练最后3层,前面的层都不需要训练。

解决办法:

①将模型拆分为两个模型,一个为前面的notop部分,一个为最后三层,然后利用model的trainable属性设置只有后一个model训练,最后将两个模型合并起来。

②不用拆分,遍历模型的所有层,将前面层的trainable设置为False即可。代码如下:

for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

注意事项:

①尽量不要这样:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

因为容易出错。。。

②加载notop参数时注意by_name=True.

补充知识:Keras关于训练冻结部分层

设置冻结层有两种方式。

(不推荐)是在搭建网络时,直接将某层的trainable设置为false,例如:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

在网络搭建完成时,遍历model.layer,然后将layer.trainable设置为False:

# 冻结网络倒数的3层
for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

也可以根据layer.name来确定哪些层需要冻结,例如冻结最后一层和RNN层:

for layer in model.layers:
 layerName=str(layer.name)
 if layerName.startswith("RNN_") or layerName.startswith("Final_"):
 layer.trainable=False

可以在实例化之后将网络层的 trainable 属性设置为 True 或 False。为了使之生效,在修改 trainable 属性之后,需要在模型上调用 compile()。

这是一个例子

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)
 
frozen_model = Model(x, y)
# 在下面的模型中,训练期间不会更新层的权重
frozen_model.compile(optimizer='rmsprop', loss='mse')
 
layer.trainable = True
trainable_model = Model(x, y)
# 使用这个模型,训练期间 `layer` 的权重将被更新
# (这也会影响上面的模型,因为它使用了同一个网络层实例)
trainable_model.compile(optimizer='rmsprop', loss='mse')
 
frozen_model.fit(data, labels) # 这不会更新 `layer` 的权重
trainable_model.fit(data, labels) # 这会更新 `layer` 的权重

在网络搭建时,可以考虑最后一个分类层命名和分类数量关联,这样当费雷数量方式变化时,model.load_weight(“weight.h5”,by_name=True)不会加载最后一层

以上这篇Keras: model实现固定部分layer,训练部分layer操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过正则查找微博@(at)用户的方法
Mar 13 Python
Python字符串替换实例分析
May 11 Python
Python实现简单文本字符串处理的方法
Jan 22 Python
python中ASCII码字符与int之间的转换方法
Jul 09 Python
Python+selenium 获取浏览器窗口坐标、句柄的方法
Oct 14 Python
Flask框架路由和视图用法实例分析
Nov 07 Python
python实现word文档批量转成自定义格式的excel文档的思路及实例代码
Feb 21 Python
python中数据库like模糊查询方式
Mar 02 Python
django实现日志按日期分割
May 21 Python
python输入中文的实例方法
Sep 14 Python
python Pexpect模块的使用
Dec 25 Python
python实现web邮箱扫描的示例(附源码)
Mar 30 Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
You might like
使用字符串函数输出整数化的PHP版本号
2006/10/09 PHP
PHP开发工具ZendStudio下Xdebug工具使用说明详解
2013/11/11 PHP
学习PHP Cookie处理函数
2016/08/09 PHP
php基于curl主动推送最新内容给百度收录的方法
2016/10/14 PHP
PHP解析url并得到url参数方法总结
2018/10/11 PHP
php使用json-schema模块实现json校验示例
2019/09/28 PHP
ppk谈JavaScript style属性
2008/10/10 Javascript
node.js中的console.trace方法使用说明
2014/12/09 Javascript
JavaScript入门基础
2015/08/12 Javascript
JavaScript事件学习小结(三)js事件对象
2016/06/09 Javascript
Angular的模块化(代码分享)
2016/12/26 Javascript
js实现产品缩略图效果
2017/03/10 Javascript
vue中用H5实现文件上传的方法实例代码
2017/05/27 Javascript
详细AngularJs4的图片剪裁组件的实例
2017/07/12 Javascript
解决VUE框架 导致绑定事件的阻止冒泡失效问题
2018/02/24 Javascript
vue js秒转天数小时分钟秒的实例代码
2018/08/08 Javascript
取消Bootstrap的dropdown-menu点击默认关闭事件方法
2018/08/10 Javascript
浅谈Vue.js 中的 v-on 事件指令的使用
2018/11/25 Javascript
微信小程序如何使用云开发
2019/05/17 Javascript
[02:32]DOTA2英雄基础教程 祸乱之源
2013/12/23 DOTA
[02:04]2014DOTA2国际邀请赛 BBC小组赛第三天总结
2014/07/12 DOTA
利用python获得时间的实例说明
2013/03/25 Python
Python基于DES算法加密解密实例
2015/06/03 Python
详解Django中的过滤器
2015/07/16 Python
python常用知识梳理(必看篇)
2017/03/23 Python
python 限制函数调用次数的实例讲解
2018/04/21 Python
Python Datetime模块和Calendar模块用法实例分析
2019/04/15 Python
django用户登录验证的完整示例代码
2019/07/21 Python
wxPython色环电阻计算器
2019/11/18 Python
Django 自动生成api接口文档教程
2019/11/19 Python
电大毕业个人生自我鉴定
2014/03/26 职场文书
《大禹治水》教学反思
2014/04/27 职场文书
六一儿童节标语
2014/10/08 职场文书
大学毕业生个人总结
2015/02/28 职场文书
计划生育目标责任书
2015/05/09 职场文书
《我的美好婚事》动画化决定纪念插画与先导PV公开
2022/04/06 日漫