tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
编写Python脚本来获取mp3文件tag信息的教程
May 04 Python
Python找出9个连续的空闲端口
Feb 01 Python
python文件与目录操作实例详解
Feb 22 Python
浅析python协程相关概念
Jan 20 Python
使用python脚本实现查询火车票工具
Jul 19 Python
对python字典过滤条件的实例详解
Jan 22 Python
wxpython多线程防假死与线程间传递消息实例详解
Dec 13 Python
Python处理mysql特殊字符的问题
Mar 02 Python
matplotlib相关系统目录获取方式小结
Feb 03 Python
Python破解极验滑动验证码详细步骤
May 21 Python
Python音乐爬虫完美绕过反爬
Aug 30 Python
python区块链实现简版工作量证明
May 25 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
php 购物车的例子
2009/05/04 PHP
php微信公众号js-sdk开发应用
2016/11/28 PHP
鼠标移入移出事件改变图片的分辨率的两种方法
2013/12/17 Javascript
禁止IE用右键的JS代码
2013/12/30 Javascript
js使用split函数按照多个字符对字符串进行分割的方法
2015/03/20 Javascript
基于jquery实现放大镜效果
2015/08/17 Javascript
基于bootstrap插件实现autocomplete自动完成表单
2016/05/07 Javascript
简单的js表格操作
2016/09/24 Javascript
jQuery网页定位导航特效实现方法
2016/12/19 Javascript
JS组件系列之JS组件封装过程详解
2017/04/28 Javascript
Vue中img的src属性绑定与static文件夹实例
2017/05/18 Javascript
Vue中保存用户登录状态实例代码
2017/06/07 Javascript
jQuery复合事件结合toggle()方法的用法示例
2017/06/10 jQuery
Node.JS中快速扫描端口并发现局域网内的Web服务器地址(80)
2017/09/18 Javascript
JS面向对象编程基础篇(一) 对象和构造函数实例详解
2020/03/03 Javascript
python 数据加密代码
2008/12/24 Python
Python3实现将文件归档到zip文件及从zip文件中读取数据的方法
2015/05/22 Python
使用Python内置的模块与函数进行不同进制的数的转换
2016/03/12 Python
Python3.8中使用f-strings调试
2019/05/22 Python
Django上使用数据可视化利器Bokeh解析
2019/07/31 Python
利用Tensorflow构建和训练自己的CNN来做简单的验证码识别方式
2020/01/20 Python
在python中logger setlevel没有生效的解决
2020/02/21 Python
使用python采集Excel表中某一格数据
2020/05/14 Python
html5 利用canvas手写签名并保存的实现方法
2018/07/12 HTML / CSS
俄罗斯运动鞋商店:Sneakerhead
2018/05/10 全球购物
德国内衣、泳装和睡衣网上商店:Bigsize Dessous
2018/07/09 全球购物
Godiva巧克力英国官网:比利时歌帝梵巧克力
2018/08/28 全球购物
美国在线购物频道:Shop LC
2019/04/21 全球购物
澳大利亚电商Catch新西兰站:Catch.co.nz
2020/05/30 全球购物
下述程序的作用是计算机数组中的最大元素值及其下标
2012/11/26 面试题
幼儿园托班开学寄语
2014/01/18 职场文书
运动会广播稿500字
2014/01/28 职场文书
财经学院自荐信范文
2014/02/02 职场文书
党的群众路线教育实践活动党员个人整改措施
2014/10/27 职场文书
2015年高中班主任工作总结
2015/04/30 职场文书
大学生入党自我鉴定范文
2019/06/21 职场文书