tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中if __name__ == '__main__'作用解析
Jun 29 Python
python实现ping的方法
Jul 06 Python
python中的break、continue、exit()、pass全面解析
Aug 05 Python
Python代码实现KNN算法
Dec 20 Python
Python自定义函数实现求两个数最大公约数、最小公倍数示例
May 21 Python
Tesserocr库的正确安装方式
Oct 19 Python
使用PM2+nginx部署python项目的方法示例
Nov 07 Python
在python中实现强制关闭线程的示例
Jan 22 Python
Python 微信之获取好友昵称并制作wordcloud的实例
Feb 21 Python
Pandas之Fillna填充缺失数据的方法
Jun 25 Python
python ETL工具 pyetl
Jun 07 Python
图文详解matlab原始处理图像几何变换
Jul 09 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
PHP小技巧搜集,每个PHPer都来露一手
2007/01/02 PHP
php实现的SESSION类
2014/12/02 PHP
PHP生成随机数的方法实例分析
2015/01/22 PHP
php实现encode64编码类实例
2015/03/24 PHP
PHP生成唯一订单号的方法汇总
2015/04/16 PHP
用PHP代码在网页上生成图片
2015/07/01 PHP
Yii 访问 Gii(脚手架)时出现 403 错误
2018/06/06 PHP
js怎么覆盖原有方法实现重写
2014/09/04 Javascript
浅析javascript中函数声明和函数表达式的区别
2015/02/15 Javascript
通过js获取上传的图片信息(临时保存路径,名称,大小)然后通过ajax传递给后端的方法
2015/10/01 Javascript
vue.js学习笔记:如何加载本地json文件
2017/01/17 Javascript
Cookies 和 Session的详解及区别
2017/04/21 Javascript
js使用generator函数同步执行ajax任务
2017/09/05 Javascript
Vue引入sass并配置全局变量的方法
2018/06/27 Javascript
vuex提交state&amp;&amp;实时监听state数据的改变方法
2018/09/16 Javascript
Vue.js 中的 v-cloak 指令及使用详解
2018/11/19 Javascript
vue-cli3 项目从搭建优化到docker部署的方法
2019/01/28 Javascript
vue19 组建 Vue.extend component、组件模版、动态组件 的实例代码
2019/04/04 Javascript
vue-iview动态新增和删除的方法
2020/06/17 Javascript
微信小程序上传帖子的实例代码(含有文字图片的微信验证)
2020/07/11 Javascript
axios封装与传参示例详解
2020/10/18 Javascript
[39:21]LGD vs OG 2019国际邀请赛淘汰赛 胜者组 BO3 第二场 8.24
2019/09/10 DOTA
详解Python中find()方法的使用
2015/05/18 Python
python解析含有重复key的json方法
2019/01/22 Python
总结Python图形用户界面和游戏开发知识点
2019/05/22 Python
解决Django加载静态资源失败的问题
2019/07/28 Python
Python过滤掉numpy.array中非nan数据实例
2020/06/08 Python
Python正则re模块使用步骤及原理解析
2020/08/18 Python
珍珠鸟教学反思
2014/02/01 职场文书
社区八一活动方案
2014/02/03 职场文书
村主任群众路线个人对照检查材料
2014/09/26 职场文书
当幸福来敲门观后感
2015/06/01 职场文书
运动会开幕式新闻稿
2015/07/17 职场文书
地震捐款简报
2015/07/21 职场文书
Java后台生成图片的完整步骤
2021/08/04 Java/Android
Dubbo+zookeeper搭配分布式服务的过程详解
2022/04/03 Java/Android