tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取图片颜色信息的方法
Mar 18 Python
介绍Python中内置的itertools模块
Apr 29 Python
python实现域名系统(DNS)正向查询的方法
Apr 19 Python
使用PyV8在Python爬虫中执行js代码
Feb 16 Python
python+Splinter实现12306抢票功能
Sep 25 Python
python opencv读mp4视频的实例
Dec 07 Python
解决Django加载静态资源失败的问题
Jul 28 Python
django 解决扩展自带User表遇到的问题
May 14 Python
Python捕获异常堆栈信息的几种方法(小结)
May 18 Python
在pycharm中创建django项目的示例代码
May 28 Python
Django mysqlclient安装和使用详解
Sep 17 Python
python unittest单元测试的步骤分析
Aug 02 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
php正则过滤html标签、空格、换行符的代码(附说明)
2010/10/25 PHP
PHP实现手机归属地查询API接口实现代码
2012/08/27 PHP
ThinkPHP行为扩展Behavior应用实例详解
2014/07/22 PHP
初识laravel5
2015/03/02 PHP
PHP 5.6.11中CURL模块问题的解决方法
2016/08/08 PHP
laravel解决迁移文件一次删除创建字段报错的问题
2019/10/24 PHP
js 小贴士一星期合集
2010/04/07 Javascript
用jquery实现自定义风格的滑动条实现代码
2011/04/26 Javascript
JavaScript中的call方法和apply方法使用对比
2015/08/12 Javascript
jQuery学习心得总结(必看篇)
2016/06/10 Javascript
AngularJS下对数组的对比分析
2016/08/24 Javascript
JavaScript排序算法动画演示效果的实现方法
2016/10/18 Javascript
Javascript highcharts 饼图显示数量和百分比实例代码
2016/12/06 Javascript
ES2015 Symbol 一种绝不重复的值
2016/12/25 Javascript
AngularJS动态绑定ng-options的ng-model实例代码
2017/06/21 Javascript
Javacript中自定义的map.js  的方法
2017/11/26 Javascript
在knockoutjs 上自己实现的flux(实例讲解)
2017/12/18 Javascript
vue props对象validator自定义函数实例
2019/11/13 Javascript
基于vue.js实现购物车
2020/01/15 Javascript
JS自定义滚动条效果
2020/03/13 Javascript
vue实现编辑器键盘抬起时内容跟随光标距顶位置向上滚动效果
2020/05/28 Javascript
[05:49]2014DOTA2TI4正赛第二日综述 昔日冠军纷纷落马 VG LGD占尽先机
2014/07/20 DOTA
Python Web框架Flask下网站开发入门实例
2015/02/08 Python
python学习笔记之调用eval函数出现invalid syntax错误问题
2015/10/18 Python
Django中针对基于类的视图添加csrf_exempt实例代码
2018/02/11 Python
Django中url的反向查询的方法
2018/03/14 Python
一份python入门应该看的学习资料
2018/04/11 Python
python实现Windows电脑定时关机
2018/06/20 Python
解决Python安装时报缺少DLL问题【两种解决方法】
2019/07/15 Python
3种方式实现瀑布流布局小结
2019/09/05 HTML / CSS
美国健康和保健平台:healtop
2020/07/02 全球购物
.NET面试10题
2014/02/24 面试题
艺术系应届生的自我评价
2013/10/19 职场文书
街道党工委党的群众路线教育实践活动对照检查材料思想汇报
2014/10/05 职场文书
营销计划书
2015/01/17 职场文书
Python Flask请求扩展与中间件相关知识总结
2021/06/11 Python