Keras: model实现固定部分layer,训练部分layer操作


Posted in Python onJune 28, 2020

需求:Resnet50做调优训练,将最后分类数目由1000改为500。

问题:网上下载了resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5,更改了Resnet50后,由于所有层均参加训练,导致训练速度慢。实际上只需要训练最后3层,前面的层都不需要训练。

解决办法:

①将模型拆分为两个模型,一个为前面的notop部分,一个为最后三层,然后利用model的trainable属性设置只有后一个model训练,最后将两个模型合并起来。

②不用拆分,遍历模型的所有层,将前面层的trainable设置为False即可。代码如下:

for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

注意事项:

①尽量不要这样:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

因为容易出错。。。

②加载notop参数时注意by_name=True.

补充知识:Keras关于训练冻结部分层

设置冻结层有两种方式。

(不推荐)是在搭建网络时,直接将某层的trainable设置为false,例如:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

在网络搭建完成时,遍历model.layer,然后将layer.trainable设置为False:

# 冻结网络倒数的3层
for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

也可以根据layer.name来确定哪些层需要冻结,例如冻结最后一层和RNN层:

for layer in model.layers:
 layerName=str(layer.name)
 if layerName.startswith("RNN_") or layerName.startswith("Final_"):
 layer.trainable=False

可以在实例化之后将网络层的 trainable 属性设置为 True 或 False。为了使之生效,在修改 trainable 属性之后,需要在模型上调用 compile()。

这是一个例子

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)
 
frozen_model = Model(x, y)
# 在下面的模型中,训练期间不会更新层的权重
frozen_model.compile(optimizer='rmsprop', loss='mse')
 
layer.trainable = True
trainable_model = Model(x, y)
# 使用这个模型,训练期间 `layer` 的权重将被更新
# (这也会影响上面的模型,因为它使用了同一个网络层实例)
trainable_model.compile(optimizer='rmsprop', loss='mse')
 
frozen_model.fit(data, labels) # 这不会更新 `layer` 的权重
trainable_model.fit(data, labels) # 这会更新 `layer` 的权重

在网络搭建时,可以考虑最后一个分类层命名和分类数量关联,这样当费雷数量方式变化时,model.load_weight(“weight.h5”,by_name=True)不会加载最后一层

以上这篇Keras: model实现固定部分layer,训练部分layer操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python生成器实现微线程编程的教程
Apr 13 Python
使用Python的Twisted框架编写简单的网络客户端
Apr 16 Python
Python中用于转换字母为小写的lower()方法使用简介
May 19 Python
Python两个内置函数 locals 和globals(学习笔记)
Aug 28 Python
深入理解Python3中的http.client模块
Mar 29 Python
python 环境变量和import模块导入方法(详解)
Jul 11 Python
Python针对给定列表中元素进行翻转操作的方法分析
Apr 27 Python
Python 使用PIL numpy 实现拼接图片的示例
May 08 Python
使用pip安装python库的多种方式
Jul 31 Python
Python容器使用的5个技巧和2个误区总结
Sep 26 Python
Python如何用re模块实现简易tokenizer
May 02 Python
python+pyhyper实现识别图片中的车牌号思路详解
Dec 24 Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
You might like
php mysql 判断update之后是否更新了的方法
2012/01/10 PHP
php中session与cookie的比较
2015/01/27 PHP
PHP微信开发之有道翻译
2016/06/23 PHP
laravel请求参数校验方法
2019/10/10 PHP
dojo学习第二天 ajax异步请求之绑定列表
2011/08/29 Javascript
jQuery实现获取h1-h6标题元素值的方法
2017/03/06 Javascript
微信小程序左右滑动的实现代码
2017/12/15 Javascript
关于axios如何全局注册浅析
2018/01/14 Javascript
js中的 || 与 && 运算符详解
2018/05/24 Javascript
vue elementUI tree树形控件获取父节点ID的实例
2018/09/12 Javascript
vue页面切换过渡transition效果
2018/10/08 Javascript
微信小程序自定义头部导航栏和导航栏背景图片 navigationStyle问题
2019/07/26 Javascript
Node Express用法详解【安装、使用、路由、中间件、模板引擎等】
2020/05/13 Javascript
Vue路由权限控制解析
2020/11/09 Javascript
关于JavaScript中异步/等待的用法与理解
2020/11/18 Javascript
python3使用tkinter实现ui界面简单实例
2014/01/10 Python
python实现BackPropagation算法
2017/12/14 Python
深入理解Django自定义信号(signals)
2018/10/15 Python
Python HTML解析模块HTMLParser用法分析【爬虫工具】
2019/04/05 Python
python实现京东订单推送到测试环境,提供便利操作示例
2019/08/09 Python
python禁用键鼠与提权代码实例
2019/08/16 Python
python生成器推导式用法简单示例
2019/10/08 Python
如何修复使用 Python ORM 工具 SQLAlchemy 时的常见陷阱
2019/11/19 Python
在Python 的线程中运行协程的方法
2020/02/24 Python
如何使用python代码操作git代码
2020/02/29 Python
Django实现whoosh搜索引擎使用jieba分词
2020/04/08 Python
MANGO官方网站:西班牙芒果服装品牌
2017/01/15 全球购物
英国快时尚女装购物网站:PrettyLittleThing
2018/08/15 全球购物
Java中实现多态的机制
2015/08/09 面试题
企业治理工作自我评价
2013/09/26 职场文书
升职自荐信范文
2013/10/05 职场文书
农业生产宣传标语
2014/10/08 职场文书
暑期辅导班宣传单
2015/07/14 职场文书
2016年社区“我们的节日·中秋节”活动总结
2016/04/05 职场文书
Mysql数据库命令大全
2021/05/26 MySQL
基于Python编写简易版的天天跑酷游戏的示例代码
2022/03/23 Python