tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中正则表达式的用法实例汇总
Aug 18 Python
python3操作微信itchat实现发送图片
Feb 24 Python
Python DataFrame设置/更改列表字段/元素类型的方法
Jun 09 Python
对python3 一组数值的归一化处理方法详解
Jul 11 Python
Python基于递归算法求最小公倍数和最大公约数示例
Jul 27 Python
Python面向对象程序设计类变量与成员变量、类方法与成员方法用法分析
Apr 12 Python
Python 异常处理Ⅳ过程图解
Oct 18 Python
Python跑循环时内存泄露的解决方法
Jan 13 Python
Django单元测试中Fixtures的使用方法
Feb 26 Python
QT5 Designer 打不开的问题及解决方法
Aug 20 Python
python如何绘制疫情图
Sep 16 Python
详解使用scrapy进行模拟登陆三种方式
Feb 21 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
详解PHP归并排序的实现
2016/10/18 PHP
让firefox支持IE的一些方法的javascript扩展函数代码
2010/01/02 Javascript
js实现GridView单选效果自动设置交替行、选中行、鼠标移动行背景色
2010/05/27 Javascript
js FLASH幻灯片字符串中有连接符&的处理方法
2012/03/01 Javascript
Jquery上传插件 uploadify v3.1使用说明
2012/06/18 Javascript
windows系统下简单nodejs安装及环境配置
2013/01/08 NodeJs
jqGrid增加时--判断开始日期与结束日期(实例解析)
2013/11/08 Javascript
form表单action提交的js部分与html部分
2014/01/07 Javascript
js中精确计算加法和减法示例
2014/03/28 Javascript
通过伪协议解决父页面与iframe页面通信的问题
2015/04/05 Javascript
使用jquery+CSS3实现仿windows10开始菜单的下拉导航菜单特效
2015/09/24 Javascript
如何使用PHP+jQuery+MySQL实现异步加载ECharts地图数据(附源码下载)
2016/02/23 Javascript
Javascript实现通过选择周数显示开始日和结束日的实现代码
2016/05/30 Javascript
vue模板语法-插值详解
2017/03/06 Javascript
javascript checkbox/radio onchange不能兼容ie8处理办法
2017/06/13 Javascript
JavaScript无操作后屏保功能的实现方法
2017/07/04 Javascript
ES6解构赋值的功能与用途实例分析
2017/10/31 Javascript
vue实现循环切换动画
2018/10/17 Javascript
微信小程序实现分享朋友圈的图片功能示例
2019/01/18 Javascript
「中高级前端面试」JavaScript手写代码无敌秘籍(推荐)
2019/04/08 Javascript
[00:15]TI9观赛名额抽取
2019/07/10 DOTA
linux系统使用python监控apache服务器进程脚本分享
2014/01/15 Python
python实现上传下载文件功能
2020/11/19 Python
Flask和Django框架中自定义模型类的表名、父类相关问题分析
2018/07/19 Python
判断python字典中key是否存在的两种方法
2018/08/10 Python
在django admin中添加自定义视图的例子
2019/07/26 Python
使用Python项目生成所有依赖包的清单方式
2020/07/13 Python
CSS3 box-sizing属性
2009/04/17 HTML / CSS
Exoticca英国:以最优惠的价格提供豪华异国情调旅行
2018/10/18 全球购物
中学生在校期间的自我评价分享
2013/11/13 职场文书
《夹竹桃》教学反思
2014/04/20 职场文书
4s店销售经理岗位职责
2014/07/19 职场文书
优秀教师先进材料
2014/12/16 职场文书
工作自我推荐信范文
2015/03/25 职场文书
爱的教育读书笔记
2015/06/26 职场文书
Javascript 解构赋值详情
2021/11/17 Javascript