解决Keras TensorFlow 混编中 trainable=False设置无效问题


Posted in Python onJune 28, 2020

这是最近碰到一个问题,先描述下问题:

首先我有一个训练好的模型(例如vgg16),我要对这个模型进行一些改变,例如添加一层全连接层,用于种种原因,我只能用TensorFlow来进行模型优化,tf的优化器,默认情况下对所有tf.trainable_variables()进行权值更新,问题就出在这,明明将vgg16的模型设置为trainable=False,但是tf的优化器仍然对vgg16做权值更新

以上就是问题描述,经过谷歌百度等等,终于找到了解决办法,下面我们一点一点的来复原整个问题。

trainable=False 无效

首先,我们导入训练好的模型vgg16,对其设置成trainable=False

from keras.applications import VGG16
import tensorflow as tf
from keras import layers
# 导入模型
base_mode = VGG16(include_top=False)
# 查看可训练的变量
tf.trainable_variables()
[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]
# 设置 trainable=False
# base_mode.trainable = False似乎也是可以的
for layer in base_mode.layers:
  layer.trainable = False

设置好trainable=False后,再次查看可训练的变量,发现并没有变化,也就是说设置无效

# 再次查看可训练的变量
tf.trainable_variables()

[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]

解决的办法

解决的办法就是在导入模型的时候建立一个variable_scope,将需要训练的变量放在另一个variable_scope,然后通过tf.get_collection获取需要训练的变量,最后通过tf的优化器中var_list指定需要训练的变量

from keras import models
with tf.variable_scope('base_model'):
  base_model = VGG16(include_top=False, input_shape=(224,224,3))
with tf.variable_scope('xxx'):
  model = models.Sequential()
  model.add(base_model)
  model.add(layers.Flatten())
  model.add(layers.Dense(10))
# 获取需要训练的变量
trainable_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'xxx')
trainable_var

[<tf.Variable 'xxx_2/dense_1/kernel:0' shape=(25088, 10) dtype=float32_ref>,
<tf.Variable 'xxx_2/dense_1/bias:0' shape=(10,) dtype=float32_ref>]

# 定义tf优化器进行训练,这里假设有一个loss
loss = model.output / 2; # 随便定义的,方便演示
train_step = tf.train.AdamOptimizer().minimize(loss, var_list=trainable_var)

总结

在keras与TensorFlow混编中,keras中设置trainable=False对于TensorFlow而言并不起作用

解决的办法就是通过variable_scope对变量进行区分,在通过tf.get_collection来获取需要训练的变量,最后通过tf优化器中var_list指定训练

以上这篇解决Keras TensorFlow 混编中 trainable=False设置无效问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现识别相似图片小结
Feb 22 Python
python 时间信息“2018-02-04 18:23:35“ 解析成字典形式的结果代码详解
Apr 19 Python
Python3多线程操作简单示例
May 22 Python
python matplotlib库直方图绘制详解
Aug 10 Python
python定位xpath 节点位置的方法
Aug 27 Python
python解析多层json操作示例
Dec 30 Python
flask利用flask-wtf验证上传的文件的方法
Jan 17 Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 Python
Python基础之字典常见操作经典实例详解
Feb 26 Python
用Python 执行cmd命令
Dec 18 Python
django 认证类配置实现
Nov 11 Python
PYTHON基于Pyecharts绘制常见的直角坐标系图表
Apr 28 Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
You might like
PHP 函数学习简单小结
2010/07/08 PHP
解析php中如何直接执行SHELL
2013/06/28 PHP
php中用memcached实现页面防刷新功能
2014/08/19 PHP
YII Framework框架教程之日志用法详解
2016/03/14 PHP
PHP中CheckBox多选框上传失败的代码写法
2017/02/13 PHP
PHP实现验证码校验功能
2017/11/16 PHP
浅谈PHP之ThinkPHP框架使用详解
2020/07/21 PHP
jQuery 改变CSS样式基础代码
2010/02/11 Javascript
EasyUI 中 MenuButton 的使用方法
2012/07/14 Javascript
js遍历、动态的添加数据的小例子
2013/06/22 Javascript
JS onmousemove鼠标移动坐标接龙DIV效果实例
2013/12/16 Javascript
jquery实现侧边弹出的垂直导航
2014/12/09 Javascript
windows8.1+iis8.5下安装node.js开发环境
2014/12/12 Javascript
如何用angularjs制作一个完整的表格
2016/01/21 Javascript
动态更新highcharts数据的实现方法
2016/05/28 Javascript
微信小程序 Windows2008 R2服务器配置TLS1.2方法
2016/12/05 Javascript
详解vue slot插槽的使用方法
2017/06/13 Javascript
vue表单自定义校验规则介绍
2018/08/28 Javascript
5分钟教你用nodeJS手写一个mock数据服务器的方法
2019/09/10 NodeJs
[01:19:35]DOTA2上海特级锦标赛主赛事日 - 3 败者组第三轮#2Fnatic VS OG第二局
2016/03/05 DOTA
[00:32]2018DOTA2亚洲邀请赛VG出场
2018/04/03 DOTA
python 美化输出信息的实例
2018/10/15 Python
pyspark操作MongoDB的方法步骤
2019/01/04 Python
python对文件目录的操作方法实例总结
2019/06/24 Python
python多线程扫描端口(线程池)
2019/09/04 Python
给Django Admin添加验证码和多次登录尝试限制的实现
2020/07/26 Python
size?瑞典:英国伦敦的球鞋精品店
2018/03/01 全球购物
Sperry澳大利亚官网:源自美国帆船鞋创始品牌
2019/07/29 全球购物
网络教育毕业生自我鉴定
2013/10/10 职场文书
2015大学生党员自我评价范文
2015/03/03 职场文书
Python list去重且保持原顺序不变的方法
2021/04/03 Python
还在手动盖楼抽奖?教你用Python实现自动评论盖楼抽奖(一)
2021/06/07 Python
Python实现单例模式的5种方法
2021/06/15 Python
详解Java分布式事务的 6 种解决方案
2021/06/26 Java/Android
MySQL数据库10秒内插入百万条数据的实现
2021/11/01 MySQL
最新最全的手机号验证正则表达式
2022/02/24 Javascript