tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
动态创建类实例代码
Oct 07 Python
python实现系统状态监测和故障转移实例方法
Nov 18 Python
详解python并发获取snmp信息及性能测试
Mar 27 Python
Python中read()、readline()和readlines()三者间的区别和用法
Jul 30 Python
python实现随机调用一个浏览器打开网页
Apr 21 Python
Python中捕获键盘的方式详解
Mar 28 Python
Python安装与基本数据类型教程详解
May 29 Python
python使用opencv对图像mask处理的方法
Jul 05 Python
PyQt 图解Qt Designer工具的使用方法
Aug 06 Python
Python利用PyPDF2库获取PDF文件总页码实例
Apr 03 Python
python实现在列表中查找某个元素的下标示例
Nov 16 Python
方法汇总:Python 安装第三方库常用
Apr 26 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
php heredoc和phpwind的模板技术使用方法小结
2008/03/28 PHP
php对二维数组进行排序的简单实例
2013/12/19 PHP
PHP静态成员变量
2017/02/14 PHP
在Laravel 的 Blade 模版中实现定义变量
2019/10/14 PHP
比较全的JS checkbox全选、取消全选、删除功能代码
2008/12/19 Javascript
Mootools 1.2教程(3) 数组使用简介
2009/09/14 Javascript
javascript 模式设计之工厂模式详细说明
2010/05/10 Javascript
javascript中运用闭包和自执行函数解决大量的全局变量问题
2010/12/30 Javascript
基于jquery实现弹幕效果
2016/09/29 Javascript
angularJS模态框$modal实例代码
2017/05/27 Javascript
Async Validator 异步验证使用说明
2017/07/03 Javascript
浅谈vue-cli 3.0.x 初体验
2018/04/11 Javascript
JavaScript装箱及拆箱boxing及unBoxing用法解析
2020/06/15 Javascript
[16:04]DOTA2海涛带你玩炸弹 9月5日更新内容详解
2014/09/05 DOTA
[04:03][TI9趣味短片] 小鸽子茶话会
2019/08/20 DOTA
在Django的视图中使用数据库查询的方法
2015/07/16 Python
Python中死锁的形成示例及死锁情况的防止
2016/06/14 Python
python垃圾回收机制(GC)原理解析
2019/12/30 Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
2020/05/25 Python
Python爬虫代理池搭建的方法步骤
2020/09/28 Python
python3 通过 pybind11 使用Eigen加速代码的步骤详解
2020/12/07 Python
CSS3制作炫酷的下拉菜单及弹起式选单的实例分享
2016/05/17 HTML / CSS
10分钟理解CSS3 FlexBox弹性布局
2018/12/20 HTML / CSS
设计部经理的岗位职责
2013/11/16 职场文书
建筑项目策划书
2014/01/13 职场文书
九年级英语教学反思
2014/01/31 职场文书
军训自我鉴定范文
2014/02/13 职场文书
小学生综合素质评语
2014/04/23 职场文书
过程装备与控制工程专业求职信
2014/07/02 职场文书
国际贸易毕业生求职信
2014/07/20 职场文书
同志主要表现材料
2014/08/21 职场文书
重点工程汇报材料
2014/08/27 职场文书
防火标语大全
2014/10/06 职场文书
迟到检讨书2000字(精选篇)
2014/10/07 职场文书
2014年精神文明建设工作总结
2014/11/19 职场文书
2016教师校本研修心得体会
2016/01/08 职场文书