tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的异常处理学习笔记
Jan 28 Python
Python3搜索及替换文件中文本的方法
May 22 Python
Python 遍历列表里面序号和值的方法(三种)
Feb 17 Python
linux安装Python3.4.2的操作方法
Sep 28 Python
解决django model修改添加字段报错的问题
Nov 18 Python
python3中rank函数的用法
Nov 27 Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 Python
Python实现井字棋小游戏
Mar 09 Python
Selenium获取登录Cookies并添加Cookies自动登录的方法
Dec 04 Python
pandas抽取行列数据的几种方法
Dec 13 Python
python 判断文件或文件夹是否存在
Mar 18 Python
python​格式化字符串
Apr 20 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
在字符串中把网址改成超级链接
2006/10/09 PHP
php socket方式提交的post详解
2008/07/19 PHP
php中时间函数date及常用的时间计算
2017/05/12 PHP
php实现支付宝当面付(扫码支付)功能
2018/05/30 PHP
浅谈PHP匿名函数和闭包
2019/03/08 PHP
javascript new 需不需要继续使用
2009/07/02 Javascript
defer属性导致引用JQuery的页面报“浏览器无法打开网站xxx,操作被中止”错误的解决方法
2010/04/27 Javascript
firefox火狐浏览器与与ie兼容的2个问题总结
2010/07/20 Javascript
javascript常用方法汇总
2014/12/02 Javascript
js实现密码强度检测【附示例】
2016/03/30 Javascript
jQuery基于ajax操作json数据简单示例
2017/01/05 Javascript
Java中int与integer的区别(基本数据类型与引用数据类型)
2017/02/19 Javascript
Vue表单验证插件Vue Validator使用方法详解
2017/04/07 Javascript
react-native fetch的具体使用方法
2017/11/01 Javascript
layui之select的option叠加问题的解决方法
2018/03/08 Javascript
webpack本地开发环境无法用IP访问的解决方法
2018/03/20 Javascript
浅谈angular2子组件的事件传递(任意组件事件传递)
2018/09/30 Javascript
js实现点赞按钮功能的实例代码
2020/03/06 Javascript
vue-router 控制路由权限的实现
2020/09/24 Javascript
python求列表交集的方法汇总
2014/11/10 Python
PyQt4实时显示文本内容GUI的示例
2019/06/14 Python
python的slice notation的特殊用法详解
2019/12/27 Python
Python实现AI自动抠图实例解析
2020/03/05 Python
Python 找出出现次数超过数组长度一半的元素实例
2020/05/11 Python
Django 允许局域网中的机器访问你的主机操作
2020/05/13 Python
python在协程中增加任务实例操作
2021/02/28 Python
HTML5对比HTML4的主要改变和改进总结
2016/05/27 HTML / CSS
HTML5 FileReader对象的具体使用方法
2020/05/22 HTML / CSS
美国大城市最热门旅游景点门票:CityPASS
2016/12/16 全球购物
世界各地的当地人的食物体验:Eatwith
2019/07/26 全球购物
自我鉴定四大框架
2014/01/17 职场文书
个性发展自我评价
2014/02/11 职场文书
汽车运用工程专业求职信
2014/06/18 职场文书
2014年政风行风自查自纠报告
2014/10/21 职场文书
CSS 文字装饰 text-decoration & text-emphasis 详解
2021/04/06 HTML / CSS
Mysql超详细讲解死锁问题的理解
2022/04/01 MySQL