tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用pil生成图片验证码的方法
May 08 Python
Python设置Socket代理及实现远程摄像头控制的例子
Nov 13 Python
Request的中断和ErrorHandler实例解析
Feb 12 Python
详解Django中CBV(Class Base Views)模型源码分析
Feb 25 Python
python实现维吉尼亚算法
Mar 20 Python
anaconda如何查看并管理python环境
Jul 05 Python
Django 导出项目依赖库到 requirements.txt过程解析
Aug 23 Python
django 做 migrate 时 表已存在的处理方法
Aug 31 Python
Python读取YAML文件过程详解
Dec 30 Python
python如何调用百度识图api
Sep 29 Python
Python 处理表格进行成绩排序的操作代码
Jul 26 Python
python3操作redis实现List列表实例
Aug 04 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
收音机怀古---春雷3P7图片欣赏
2021/03/02 无线电
PHPlet在Windows下的安装
2006/10/09 PHP
PHP4中实现动态代理
2006/10/09 PHP
关于mysql字符集设置了character_set_client=binary 在gbk情况下会出现表描述是乱码的情况
2013/01/06 PHP
利用Homestead快速运行一个Laravel项目的方法详解
2017/11/14 PHP
Thinkphp5框架中引入Markdown编辑器操作示例
2020/06/03 PHP
javascript js cookie的存储,获取和删除
2007/12/29 Javascript
Javascript 表单之间的数据传递代码
2008/12/04 Javascript
JavaScript入门教程(2) JS基础知识
2009/01/31 Javascript
JavaScript中的面向对象介绍
2012/06/30 Javascript
node.js中的require使用详解
2014/12/15 Javascript
JS实现的网页倒计时数字时钟效果
2015/03/02 Javascript
Jquery异步提交表单代码分享
2015/03/26 Javascript
JavaScript 面向对象与原型
2015/04/10 Javascript
jQuery实现控制文字内容溢出用省略号(…)表示的方法
2016/02/26 Javascript
聊一聊JS中this的指向问题
2016/06/17 Javascript
jQuery中的siblings()是什么意思(推荐)
2016/12/29 Javascript
详解vue.js的事件处理器v-on:click
2017/06/27 Javascript
AngularJS中ng-class用法实例分析
2017/07/06 Javascript
React Native 使用Fetch发送网络请求的示例代码
2017/12/02 Javascript
vue+springboot前后端分离实现单点登录跨域问题解决方法
2018/01/30 Javascript
JavaScript链式调用实例浅析
2018/12/19 Javascript
Vue和React组件之间的传值方式详解
2019/01/31 Javascript
使用JQuery自动完成插件Auto Complete详解
2019/06/18 jQuery
[06:04]DOTA2国际邀请赛纪录片:Just For LGD
2013/08/11 DOTA
Python 制作糗事百科爬虫实例
2016/09/22 Python
K-近邻算法的python实现代码分享
2017/12/09 Python
canvas绘制树形结构可视图形的实现
2020/04/03 HTML / CSS
日本动漫周边服饰销售网站:Atsuko
2019/12/16 全球购物
购房意向书范本
2014/04/01 职场文书
创建文明城市标语
2014/06/16 职场文书
公共场所禁烟倡议书
2014/08/30 职场文书
个人委托书怎么写
2014/09/17 职场文书
mysql知识点整理
2021/04/05 MySQL
Python中os模块的简单使用及重命名操作
2021/04/17 Python
MySQL索引 高效获取数据的数据结构
2022/05/02 MySQL