Keras: model实现固定部分layer,训练部分layer操作


Posted in Python onJune 28, 2020

需求:Resnet50做调优训练,将最后分类数目由1000改为500。

问题:网上下载了resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5,更改了Resnet50后,由于所有层均参加训练,导致训练速度慢。实际上只需要训练最后3层,前面的层都不需要训练。

解决办法:

①将模型拆分为两个模型,一个为前面的notop部分,一个为最后三层,然后利用model的trainable属性设置只有后一个model训练,最后将两个模型合并起来。

②不用拆分,遍历模型的所有层,将前面层的trainable设置为False即可。代码如下:

for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

注意事项:

①尽量不要这样:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

因为容易出错。。。

②加载notop参数时注意by_name=True.

补充知识:Keras关于训练冻结部分层

设置冻结层有两种方式。

(不推荐)是在搭建网络时,直接将某层的trainable设置为false,例如:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

在网络搭建完成时,遍历model.layer,然后将layer.trainable设置为False:

# 冻结网络倒数的3层
for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

也可以根据layer.name来确定哪些层需要冻结,例如冻结最后一层和RNN层:

for layer in model.layers:
 layerName=str(layer.name)
 if layerName.startswith("RNN_") or layerName.startswith("Final_"):
 layer.trainable=False

可以在实例化之后将网络层的 trainable 属性设置为 True 或 False。为了使之生效,在修改 trainable 属性之后,需要在模型上调用 compile()。

这是一个例子

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)
 
frozen_model = Model(x, y)
# 在下面的模型中,训练期间不会更新层的权重
frozen_model.compile(optimizer='rmsprop', loss='mse')
 
layer.trainable = True
trainable_model = Model(x, y)
# 使用这个模型,训练期间 `layer` 的权重将被更新
# (这也会影响上面的模型,因为它使用了同一个网络层实例)
trainable_model.compile(optimizer='rmsprop', loss='mse')
 
frozen_model.fit(data, labels) # 这不会更新 `layer` 的权重
trainable_model.fit(data, labels) # 这会更新 `layer` 的权重

在网络搭建时,可以考虑最后一个分类层命名和分类数量关联,这样当费雷数量方式变化时,model.load_weight(“weight.h5”,by_name=True)不会加载最后一层

以上这篇Keras: model实现固定部分layer,训练部分layer操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python人人网登录应用实例
Sep 26 Python
python中list循环语句用法实例
Nov 10 Python
Python兔子毒药问题实例分析
Mar 05 Python
python实现下载文件的三种方法
Feb 09 Python
Python 闭包的使用方法
Sep 07 Python
记一次python 内存泄漏问题及解决过程
Nov 29 Python
django echarts饼图数据动态加载的实例
Aug 12 Python
Pycharm中安装wordcloud等库失败问题及终端通过pip安装的Python库如何添加到Pycharm解释器中(推荐)
May 10 Python
解决pycharm导入本地py文件时,模块下方出现红色波浪线的问题
Jun 01 Python
Pycharm操作Git及GitHub的步骤详解
Oct 27 Python
python Timer 类使用介绍
Dec 28 Python
python中filter,map,reduce的作用
Jun 10 Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
You might like
PHP 模板高级篇总结
2006/12/21 PHP
在php中判断一个请求是ajax请求还是普通请求的方法
2011/06/28 PHP
php实现登陆模块功能示例
2016/10/20 PHP
javascript 变量作用域 代码分析
2009/06/26 Javascript
解决火狐浏览器下JS setTimeout函数不兼容失效不执行的方法
2012/11/14 Javascript
extjs tabpanel限制选项卡数量实现思路及代码
2013/04/02 Javascript
js字符串转换成xml对象并使用技巧解读
2013/04/18 Javascript
使用js画图之圆、弧、扇形
2015/01/12 Javascript
jquery使用remove()方法删除指定class子元素
2015/03/26 Javascript
CSS或者JS实现鼠标悬停显示另一元素
2016/01/22 Javascript
jstl中判断list中是否包含某个值的简单方法
2016/10/14 Javascript
JS鼠标滚动分页效果示例
2017/07/05 Javascript
Vue 多层组件嵌套二种实现方式(测试实例)
2017/09/08 Javascript
Vue.js用法详解
2017/11/13 Javascript
浅析JS中回调函数及用法
2018/07/25 Javascript
关于JavaScript 数组你应该知道的事情(推荐)
2019/04/10 Javascript
小程序异步问题之多个网络请求依次执行并依次收集请求结果
2019/05/05 Javascript
使用Vant完成DatetimePicker 日期的选择器操作
2020/11/12 Javascript
[04:52]2015国际邀请赛LGD战队晋级之路
2015/08/14 DOTA
Python比较文件夹比另一同名文件夹多出的文件并复制出来的方法
2015/03/05 Python
Python实现将目录中TXT合并成一个大TXT文件的方法
2015/07/15 Python
Python实现UDP程序通信过程图解
2020/05/15 Python
摩飞电器俄罗斯官方网站:Morphy Richards俄罗斯
2020/07/30 全球购物
简述安装Slackware Linux系统的过程
2012/01/12 面试题
简历中自我评价分享
2013/10/09 职场文书
铁路个人事迹材料
2014/01/30 职场文书
《山谷中的谜底》教学反思
2014/04/26 职场文书
英语复习计划
2015/01/19 职场文书
优秀团员自我评价
2015/03/10 职场文书
2015年乡镇纪检工作总结
2015/04/22 职场文书
义卖募捐活动总结
2015/05/09 职场文书
工作简报怎么写
2015/07/21 职场文书
2016教师年度考核评语大全
2015/12/01 职场文书
关于艺术节的开幕致辞
2016/03/04 职场文书
Python 实现Mac 屏幕截图详解
2021/10/05 Python
基于docker安装zabbix的详细教程
2022/06/05 Servers