解决Keras TensorFlow 混编中 trainable=False设置无效问题


Posted in Python onJune 28, 2020

这是最近碰到一个问题,先描述下问题:

首先我有一个训练好的模型(例如vgg16),我要对这个模型进行一些改变,例如添加一层全连接层,用于种种原因,我只能用TensorFlow来进行模型优化,tf的优化器,默认情况下对所有tf.trainable_variables()进行权值更新,问题就出在这,明明将vgg16的模型设置为trainable=False,但是tf的优化器仍然对vgg16做权值更新

以上就是问题描述,经过谷歌百度等等,终于找到了解决办法,下面我们一点一点的来复原整个问题。

trainable=False 无效

首先,我们导入训练好的模型vgg16,对其设置成trainable=False

from keras.applications import VGG16
import tensorflow as tf
from keras import layers
# 导入模型
base_mode = VGG16(include_top=False)
# 查看可训练的变量
tf.trainable_variables()
[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]
# 设置 trainable=False
# base_mode.trainable = False似乎也是可以的
for layer in base_mode.layers:
  layer.trainable = False

设置好trainable=False后,再次查看可训练的变量,发现并没有变化,也就是说设置无效

# 再次查看可训练的变量
tf.trainable_variables()

[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]

解决的办法

解决的办法就是在导入模型的时候建立一个variable_scope,将需要训练的变量放在另一个variable_scope,然后通过tf.get_collection获取需要训练的变量,最后通过tf的优化器中var_list指定需要训练的变量

from keras import models
with tf.variable_scope('base_model'):
  base_model = VGG16(include_top=False, input_shape=(224,224,3))
with tf.variable_scope('xxx'):
  model = models.Sequential()
  model.add(base_model)
  model.add(layers.Flatten())
  model.add(layers.Dense(10))
# 获取需要训练的变量
trainable_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'xxx')
trainable_var

[<tf.Variable 'xxx_2/dense_1/kernel:0' shape=(25088, 10) dtype=float32_ref>,
<tf.Variable 'xxx_2/dense_1/bias:0' shape=(10,) dtype=float32_ref>]

# 定义tf优化器进行训练,这里假设有一个loss
loss = model.output / 2; # 随便定义的,方便演示
train_step = tf.train.AdamOptimizer().minimize(loss, var_list=trainable_var)

总结

在keras与TensorFlow混编中,keras中设置trainable=False对于TensorFlow而言并不起作用

解决的办法就是通过variable_scope对变量进行区分,在通过tf.get_collection来获取需要训练的变量,最后通过tf优化器中var_list指定训练

以上这篇解决Keras TensorFlow 混编中 trainable=False设置无效问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
比较详细Python正则表达式操作指南(re使用)
Sep 06 Python
python 中的divmod数字处理函数浅析
Oct 17 Python
python中hashlib模块用法示例
Oct 30 Python
Python判断文件和字符串编码类型的实例
Dec 21 Python
Python实现可设置持续运行时间、线程数及时间间隔的多线程异步post请求功能
Jan 11 Python
python3调用R的示例代码
Feb 23 Python
python GUI库图形界面开发之PyQt5工具栏控件QToolBar的详细使用方法与实例
Feb 28 Python
详解django使用include无法跳转的解决方法
Mar 19 Python
如何表示python中的相对路径
Jul 08 Python
python 多线程中join()的作用
Oct 29 Python
Django自带的用户验证系统实现
Dec 18 Python
Python中tqdm的使用和例子
Sep 23 Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
You might like
ThinkPHP的截取字符串函数无法显示省略号的解决方法
2014/06/25 PHP
PHP SFTP实现上传下载功能
2017/07/26 PHP
OAuth认证协议中的HMACSHA1加密算法(实例)
2017/10/25 PHP
Linux下源码包安装Swoole及基本使用操作图文详解
2019/04/02 PHP
使Ext的Template可以解析二层的json数据的方法
2007/12/22 Javascript
js控制div及网页相关属性的代码
2009/12/19 Javascript
一段批量给页面上的控件赋值js
2010/06/19 Javascript
模拟jQuery中的ready方法及实现按需加载css,js实例代码
2013/09/27 Javascript
JS方法调用括号的问题探讨
2014/01/24 Javascript
easyui Draggable组件实现拖动效果
2015/08/19 Javascript
javascript实现可键盘控制的抽奖系统
2016/03/10 Javascript
jquery实现表格中点击相应行变色功能效果【实例代码】
2016/05/09 Javascript
jQuery实现鼠标滑过预览图片大图效果的方法
2017/04/26 jQuery
Vue 组件间的样式冲突污染
2017/08/31 Javascript
MVVM 双向绑定的实现代码
2018/06/21 Javascript
jQuery实现弹出层效果
2019/12/10 jQuery
[14:21]VICI vs EG (BO3)
2018/06/07 DOTA
[52:26]完美世界DOTA2联赛决赛 FTD vs Phoenix 第一场 11.08
2020/11/11 DOTA
闭包在python中的应用之translate和maketrans用法详解
2014/08/27 Python
python命令行参数解析OptionParser类用法实例
2014/10/09 Python
Python实现的生产者、消费者问题完整实例
2018/05/30 Python
使用Python进行QQ批量登录的实例代码
2018/06/11 Python
python小程序实现刷票功能详解
2019/07/17 Python
pycharm创建scrapy项目教程及遇到的坑解析
2019/08/15 Python
用python写测试数据文件过程解析
2019/09/25 Python
Jupyter notebook如何修改平台字体
2020/05/13 Python
python 发送邮件的四种方法汇总
2020/12/02 Python
Canvas系列之滤镜效果
2019/02/12 HTML / CSS
英国最大的奢侈珠宝和手表网站:C W Sellors
2017/02/10 全球购物
德国家具购物网站:Möbel Höffner
2019/08/26 全球购物
主题酒店策划书
2014/01/28 职场文书
大学英语专业求职信
2014/06/21 职场文书
中职生求职信
2014/07/01 职场文书
2014年个人售房协议书
2014/10/30 职场文书
5.12护士节活动总结
2015/02/10 职场文书
MYSQL 的10大经典优化案例场景实战
2021/09/14 MySQL