tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之集合的关系
Sep 24 Python
在Python的Flask框架中使用模版的入门教程
Apr 20 Python
python获取元素在数组中索引号的方法
Jul 15 Python
30秒轻松实现TensorFlow物体检测
Mar 14 Python
对Python Class之间函数的调用关系详解
Jan 23 Python
opencv设置采集视频分辨率方式
Dec 10 Python
Python 写了个新型冠状病毒疫情传播模拟程序
Feb 14 Python
Python实现企业微信机器人每天定时发消息实例
Feb 25 Python
Python编程快速上手——正则表达式查找功能案例分析
Feb 28 Python
利用django model save方法对未更改的字段依然进行了保存
Mar 28 Python
django模板获取list中指定索引的值方式
May 14 Python
python 调用Google翻译接口的方法
Dec 09 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
使用NetBeans + Xdebug调试PHP程序的方法
2011/04/12 PHP
php中session使用示例
2014/03/29 PHP
Ext.get() 和 Ext.query()组合使用实现最灵活的取元素方式
2011/09/26 Javascript
showModalDialog在谷歌浏览器下会返回Null的解决方法
2013/11/27 Javascript
jquery禁用右键单击功能屏蔽F5刷新
2014/03/17 Javascript
JavaScript获取table中某一列的值的方法
2014/05/06 Javascript
avascript中的自执行匿名函数应用示例
2014/09/15 Javascript
JavaScript中switch语句的用法详解
2015/06/03 Javascript
浅谈JavaScript中变量和函数声明的提升
2016/08/09 Javascript
JS获取及验证开始结束日期的方法
2016/08/20 Javascript
js正则表达式验证表单【完整版】
2017/03/06 Javascript
JavaScript实现多重继承的方法分析
2018/01/09 Javascript
AngularJS上传文件的示例代码
2018/11/10 Javascript
利用原生JS实现data方法示例代码
2019/05/28 Javascript
jquery实现自定义树形表格的方法【自定义树形结构table】
2019/07/12 jQuery
node删除、复制文件或文件夹示例代码
2019/08/13 Javascript
JS实现可视化音频效果的实例代码
2020/01/16 Javascript
JS hasOwnProperty()方法检测一个属性是否是对象的自有属性的方法
2021/01/29 Javascript
[01:00:30]完美世界DOTA2联赛循环赛 Inki vs Matador BO2第二场 10.31
2020/11/02 DOTA
python正则表达式修复网站文章字体不统一的解决方法
2013/02/21 Python
python使用append合并两个数组的方法
2015/04/28 Python
python 换位密码算法的实例详解
2017/07/19 Python
对python3中pathlib库的Path类的使用详解
2018/10/14 Python
bluepy 一款python封装的BLE利器简单介绍
2019/06/25 Python
Python 使用matplotlib模块模拟掷骰子
2019/08/08 Python
Anconda环境下Vscode安装Python的方法详解
2020/03/29 Python
Python urllib3软件包的使用说明
2020/11/18 Python
数据库基础的一些面试题
2012/02/25 面试题
Set里的元素是不能重复的,那么用什么方法来区分重复与否呢? 是用==还是equals()? 它们有何区别?
2014/07/27 面试题
中国文明网签名寄语
2014/01/18 职场文书
酒店节能降耗方案
2014/05/08 职场文书
企业职业病防治方案
2014/05/29 职场文书
乡镇干部个人整改措施思想汇报
2014/10/10 职场文书
市级三好学生评语
2014/12/29 职场文书
2015年学校管理工作总结
2015/07/20 职场文书
Qt数据库应用之实现图片转pdf
2022/06/01 Java/Android