tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Tkinter GUI编程入门介绍
Mar 10 Python
Python比较两个图片相似度的方法
Mar 13 Python
python根据京东商品url获取产品价格
Aug 09 Python
深入解析Python中的线程同步方法
Jun 14 Python
python opencv检测目标颜色的实例讲解
Apr 02 Python
python批量修改图片后缀的方法(png到jpg)
Oct 25 Python
Python使用修饰器进行异常日志记录操作示例
Mar 19 Python
PyQt5实现从主窗口打开子窗口的方法
Jun 19 Python
python的scipy实现插值的示例代码
Nov 12 Python
Tensorflow tf.dynamic_partition矩阵拆分示例(Python3)
Feb 07 Python
Python动态导入模块:__import__、importlib、动态导入的使用场景实例分析
Mar 30 Python
利用4行Python代码监测每一行程序的运行时间和空间消耗
Apr 22 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
《PHP边学边教》(02.Apache+PHP环境配置――下篇)
2006/12/13 PHP
PHP语言中global和$GLOBALS[]的分析 之二
2012/02/02 PHP
解决nginx不支持thinkphp中pathinfo的问题
2015/07/21 PHP
PHP实现SMTP邮件的发送实例
2018/09/27 PHP
php常用字符串长度函数strlen()与mb_strlen()用法实例分析
2019/06/25 PHP
laravel实现Auth认证,登录、注册后的页面回跳方法
2019/09/30 PHP
任意位置显示html菜单
2007/02/01 Javascript
javascript Ext JS 状态默认存储时间
2009/02/15 Javascript
jQuery 自动增长的文本输入框实现代码
2010/04/02 Javascript
JSQL  一个 web DB 的封装
2010/05/05 Javascript
node.js发送邮件email的方法详解
2017/01/06 Javascript
vue动态生成dom并且自动绑定事件
2017/04/19 Javascript
详解如何将angular-ui的图片轮播组件封装成一个指令
2017/05/09 Javascript
vue中各组件之间传递数据的方法示例
2017/07/27 Javascript
vue-test-utils初使用详解
2019/05/23 Javascript
通过JS深度判断两个对象字段相同
2019/06/14 Javascript
小程序如何获取多个formId实现详解
2019/09/20 Javascript
如何构建 vue-ssr 项目的方法步骤
2020/08/04 Javascript
SpringBoot在yml配置文件中配置druid的操作
2020/11/16 Javascript
[58:18]2018DOTA2亚洲邀请赛3月29日 小组赛B组 iG VS Mineski
2018/03/30 DOTA
[42:32]完美世界DOTA2联赛循环赛 Magma vs PXG BO2第二场 10.28
2020/10/28 DOTA
Python 文件处理注意事项总结
2017/04/10 Python
python中使用正则表达式的后向搜索肯定模式(推荐)
2017/11/11 Python
Python中static相关知识小结
2018/01/02 Python
基于Python实现的ID3决策树功能示例
2018/01/02 Python
TensorFlow的权值更新方法
2018/06/14 Python
Python3使用pandas模块读写excel操作示例
2018/07/03 Python
大家都说好用的Python命令行库click的使用
2019/11/07 Python
Canvas高级路径操作之拖拽对象的实现
2019/08/05 HTML / CSS
奥地利手表、香水、化妆品和珠宝购物网站:Brasty.at
2021/01/17 全球购物
strstr()的简单实现
2013/09/26 面试题
一道写SQL的面试题和答案
2013/11/19 面试题
活动总结模板
2014/05/09 职场文书
财务工作失职检讨书
2014/11/21 职场文书
JS实现简单控制视频播放倍速的实例代码
2021/04/18 Javascript
Python实现Matplotlib,Seaborn动态数据图
2022/05/06 Python