Keras: model实现固定部分layer,训练部分layer操作


Posted in Python onJune 28, 2020

需求:Resnet50做调优训练,将最后分类数目由1000改为500。

问题:网上下载了resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5,更改了Resnet50后,由于所有层均参加训练,导致训练速度慢。实际上只需要训练最后3层,前面的层都不需要训练。

解决办法:

①将模型拆分为两个模型,一个为前面的notop部分,一个为最后三层,然后利用model的trainable属性设置只有后一个model训练,最后将两个模型合并起来。

②不用拆分,遍历模型的所有层,将前面层的trainable设置为False即可。代码如下:

for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

注意事项:

①尽量不要这样:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

因为容易出错。。。

②加载notop参数时注意by_name=True.

补充知识:Keras关于训练冻结部分层

设置冻结层有两种方式。

(不推荐)是在搭建网络时,直接将某层的trainable设置为false,例如:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

在网络搭建完成时,遍历model.layer,然后将layer.trainable设置为False:

# 冻结网络倒数的3层
for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

也可以根据layer.name来确定哪些层需要冻结,例如冻结最后一层和RNN层:

for layer in model.layers:
 layerName=str(layer.name)
 if layerName.startswith("RNN_") or layerName.startswith("Final_"):
 layer.trainable=False

可以在实例化之后将网络层的 trainable 属性设置为 True 或 False。为了使之生效,在修改 trainable 属性之后,需要在模型上调用 compile()。

这是一个例子

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)
 
frozen_model = Model(x, y)
# 在下面的模型中,训练期间不会更新层的权重
frozen_model.compile(optimizer='rmsprop', loss='mse')
 
layer.trainable = True
trainable_model = Model(x, y)
# 使用这个模型,训练期间 `layer` 的权重将被更新
# (这也会影响上面的模型,因为它使用了同一个网络层实例)
trainable_model.compile(optimizer='rmsprop', loss='mse')
 
frozen_model.fit(data, labels) # 这不会更新 `layer` 的权重
trainable_model.fit(data, labels) # 这会更新 `layer` 的权重

在网络搭建时,可以考虑最后一个分类层命名和分类数量关联,这样当费雷数量方式变化时,model.load_weight(“weight.h5”,by_name=True)不会加载最后一层

以上这篇Keras: model实现固定部分layer,训练部分layer操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过函数属性实现全局变量的方法
May 16 Python
matplotlib在python上绘制3D散点图实例详解
Dec 09 Python
python实现嵌套列表平铺的两种方法
Nov 08 Python
Python线性拟合实现函数与用法示例
Dec 13 Python
Python中文件的写入读取以及附加文字方法
Jan 23 Python
python中时间转换datetime和pd.to_datetime详析
Aug 11 Python
Python字符串中添加、插入特定字符的方法
Sep 10 Python
Python Sympy计算梯度、散度和旋度的实例
Dec 06 Python
python实现TCP文件传输
Mar 20 Python
Python yield生成器和return对比代码实例
Apr 20 Python
Python操作word文档插入图片和表格的实例演示
Oct 25 Python
详解Python常用的魔法方法
Jun 03 Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
You might like
php实现利用phpexcel导出数据
2013/08/24 PHP
php中Session的生成机制、回收机制和存储机制探究
2014/08/19 PHP
kindeditor 加入七牛云上传的实例讲解
2017/11/12 PHP
php pdo连接数据库操作示例
2019/11/18 PHP
jquery $.ajax入门应用一
2008/11/19 Javascript
javascript面向对象的方式实现的弹出层效果代码
2010/01/28 Javascript
javascript权威指南 学习笔记之javascript数据类型
2011/09/24 Javascript
jQuery获取注册信息并提示实现代码
2013/04/21 Javascript
jquery弹出框的用法示例(一)
2013/08/26 Javascript
Js与下拉列表处理问题解决
2014/02/13 Javascript
我的NodeJs学习小结(一)
2014/07/06 NodeJs
javascript解析json实例详解
2014/11/05 Javascript
JavaScript实现的一个倒计时的类
2015/03/12 Javascript
理解AngularJs指令
2015/12/10 Javascript
Bootstrap Paginator分页插件使用方法详解
2016/05/30 Javascript
Bootstrap表单布局样式源代码
2016/07/04 Javascript
JavaScript判断浏览器及其版本信息
2017/01/20 Javascript
基于JavaScript实现带数据验证和复选框的表单提交
2017/08/23 Javascript
webpack实用小功能介绍
2018/01/02 Javascript
CentOS环境中MySQL修改root密码方法
2018/01/07 Javascript
mpvue 单文件页面配置详解
2018/12/02 Javascript
vue图片上传组件使用详解
2019/12/23 Javascript
详解在IDEA中将Echarts引入web两种方式(使用js文件和maven的依赖导入)
2020/07/11 Javascript
深入了解Vue.js 混入(mixins)
2020/07/23 Javascript
python+selenium+autoit实现文件上传功能
2017/08/23 Python
Python实现KNN邻近算法
2021/01/28 Python
Django csrf 验证问题的实现
2018/10/09 Python
PyQT5 QTableView显示绑定数据的实例详解
2019/06/25 Python
python归并排序算法过程实例讲解
2020/11/04 Python
css3实例教程 一款纯css3实现的环形导航菜单
2014/10/20 HTML / CSS
关于环保的建议书400字
2014/03/12 职场文书
幼儿园小班教师个人工作总结
2015/02/06 职场文书
2016年全国爱牙日宣传活动总结
2016/04/05 职场文书
MYSQL 运算符总结
2021/11/11 MySQL
Python探索生命起源 matplotlib细胞自动机动画演示
2022/04/21 Python
Python查找算法的实现 (线性、二分,分块、插值查找算法)
2022/04/24 Python