Keras: model实现固定部分layer,训练部分layer操作


Posted in Python onJune 28, 2020

需求:Resnet50做调优训练,将最后分类数目由1000改为500。

问题:网上下载了resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5,更改了Resnet50后,由于所有层均参加训练,导致训练速度慢。实际上只需要训练最后3层,前面的层都不需要训练。

解决办法:

①将模型拆分为两个模型,一个为前面的notop部分,一个为最后三层,然后利用model的trainable属性设置只有后一个model训练,最后将两个模型合并起来。

②不用拆分,遍历模型的所有层,将前面层的trainable设置为False即可。代码如下:

for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

注意事项:

①尽量不要这样:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

因为容易出错。。。

②加载notop参数时注意by_name=True.

补充知识:Keras关于训练冻结部分层

设置冻结层有两种方式。

(不推荐)是在搭建网络时,直接将某层的trainable设置为false,例如:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

在网络搭建完成时,遍历model.layer,然后将layer.trainable设置为False:

# 冻结网络倒数的3层
for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

也可以根据layer.name来确定哪些层需要冻结,例如冻结最后一层和RNN层:

for layer in model.layers:
 layerName=str(layer.name)
 if layerName.startswith("RNN_") or layerName.startswith("Final_"):
 layer.trainable=False

可以在实例化之后将网络层的 trainable 属性设置为 True 或 False。为了使之生效,在修改 trainable 属性之后,需要在模型上调用 compile()。

这是一个例子

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)
 
frozen_model = Model(x, y)
# 在下面的模型中,训练期间不会更新层的权重
frozen_model.compile(optimizer='rmsprop', loss='mse')
 
layer.trainable = True
trainable_model = Model(x, y)
# 使用这个模型,训练期间 `layer` 的权重将被更新
# (这也会影响上面的模型,因为它使用了同一个网络层实例)
trainable_model.compile(optimizer='rmsprop', loss='mse')
 
frozen_model.fit(data, labels) # 这不会更新 `layer` 的权重
trainable_model.fit(data, labels) # 这会更新 `layer` 的权重

在网络搭建时,可以考虑最后一个分类层命名和分类数量关联,这样当费雷数量方式变化时,model.load_weight(“weight.h5”,by_name=True)不会加载最后一层

以上这篇Keras: model实现固定部分layer,训练部分layer操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现递归遍历文件夹并删除文件
Apr 18 Python
Python实现基于多线程、多用户的FTP服务器与客户端功能完整实例
Aug 18 Python
基于Python的文件类型和字符串详解
Dec 21 Python
PyQt弹出式对话框的常用方法及标准按钮类型
Feb 27 Python
使用Python-OpenCV向图片添加噪声的实现(高斯噪声、椒盐噪声)
May 28 Python
使用celery和Django处理异步任务的流程分析
Feb 19 Python
python之MSE、MAE、RMSE的使用
Feb 24 Python
解决Django中checkbox复选框的传值问题
Mar 31 Python
基于python实现操作git过程代码解析
Jul 27 Python
Python离线安装各种库及pip的方法
Nov 28 Python
python 批量压缩图片的脚本
Jun 02 Python
学会Python数据可视化必须尝试这7个库
Jun 16 Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
You might like
PHP正则表达式之定界符和原子介绍
2012/10/05 PHP
精美漂亮的php分页类代码
2013/04/02 PHP
PHP邮件发送类PHPMailer用法实例详解
2014/09/22 PHP
php cli配置文件问题分析
2015/10/15 PHP
微信网页授权(OAuth2.0) PHP 源码简单实现
2016/08/29 PHP
PHP面向对象程序设计之类与反射API详解
2016/12/02 PHP
PDO::exec讲解
2019/01/28 PHP
PHP如何防止XSS攻击与XSS攻击原理的讲解
2019/03/22 PHP
PHP 计算两个时间段之间交集的天数示例
2019/10/24 PHP
JSQL 基于客户端的成绩统计实现方法
2010/05/05 Javascript
JQuery对checkbox操作 (循环获取)
2011/05/20 Javascript
Javascript和Java获取各种form表单信息的简单实例
2014/02/14 Javascript
js实现照片墙功能实例
2015/02/05 Javascript
javascript获取元素离文档各边距离的方法
2015/02/13 Javascript
Jquery Mobile 自定义按钮图标
2015/11/18 Javascript
前端js弹出框组件使用方法
2020/08/24 Javascript
JS简单判断字符在另一个字符串中出现次数的2种常用方法
2017/04/20 Javascript
JavaScript实现简单评论功能
2017/08/17 Javascript
动态Axios的配置步骤详解
2018/01/12 Javascript
vue项目中自定义video视频控制条的实现代码
2020/04/26 Javascript
在vue中使用el-tab-pane v-show/v-if无效的解决
2020/08/03 Javascript
浅谈python for循环的巧妙运用(迭代、列表生成式)
2017/09/26 Python
Pandas 对Dataframe结构排序的实现方法
2018/04/10 Python
Pandas之drop_duplicates:去除重复项方法
2018/04/18 Python
python ChainMap 合并字典的实现步骤
2019/06/11 Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
2020/06/10 Python
HTML5 新表单类型示例代码
2018/03/20 HTML / CSS
一封普通求职者的求职信
2013/11/20 职场文书
学校安全教育制度
2014/01/31 职场文书
办理房产过户的委托书
2014/09/14 职场文书
群众路线对照检查材料思想汇报怎么写
2014/09/18 职场文书
工厂标语大全
2014/10/06 职场文书
小学生暑假安全公约
2015/07/14 职场文书
评奖评优个人先进事迹材料
2015/11/04 职场文书
2016廉洁从政心得体会
2016/01/19 职场文书
mybatis3中@SelectProvider传递参数方式
2021/08/04 Java/Android