tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
栈和队列数据结构的基本概念及其相关的Python实现
Aug 24 Python
Pycharm学习教程(4) Python解释器的相关配置
May 03 Python
Python随机生成手机号、数字的方法详解
Jul 21 Python
Python3数据库操作包pymysql的操作方法
Jul 16 Python
利用python循环创建多个文件的方法
Oct 25 Python
python遍历小写英文字母的方法
Jan 02 Python
Python利用lxml模块爬取豆瓣读书排行榜的方法与分析
Apr 15 Python
Python图片的横坐标汉字实例
Dec 04 Python
Python类如何定义私有变量
Feb 03 Python
基于python实现把json数据转换成Excel表格
May 07 Python
django使用多个数据库的方法实例
Mar 04 Python
Django程序的优化技巧
Apr 29 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
解析php下载远程图片函数 可伪造来路
2013/06/25 PHP
Linux下PHP连接Oracle数据库
2014/08/20 PHP
详解php框架Yaf路由重写
2017/06/20 PHP
Laravel 模型使用软删除-左连接查询-表起别名示例
2019/10/24 PHP
用jquery来定位
2007/02/20 Javascript
JQuery Ajax 跨域访问的解决方案
2010/03/12 Javascript
鼠标滚轮控制网页横向移动实现思路
2013/03/22 Javascript
跟我学Nodejs(一)--- Node.js简介及安装开发环境
2014/05/20 NodeJs
JS实现图片无间断滚动代码汇总
2014/07/30 Javascript
Node.js DES加密的简单实现
2016/07/07 Javascript
Angular路由简单学习
2016/12/26 Javascript
JS简单实现自定义右键菜单实例
2017/05/31 Javascript
JS通过调用微信API实现微信支付功能的方法示例
2017/06/29 Javascript
微信小程序自定义组件
2017/08/16 Javascript
如何抽象一个Vue公共组件
2017/10/17 Javascript
AngularJs 禁止模板缓存的方法
2017/11/28 Javascript
jQuery第一次运行页面默认触发点击事件的实例
2018/01/10 jQuery
浅谈vue中$event理解和框架中在包含默认值外传参
2020/08/07 Javascript
vue 页面跳转的实现方式
2021/01/12 Vue.js
[38:21]2018DOTA2亚洲邀请赛3月30日 小组赛A组 LGD VS Newbee
2018/03/31 DOTA
[01:31:02]TNC vs VG 2019国际邀请赛淘汰赛 胜者组赛BO3 第一场
2019/08/22 DOTA
Python中实现远程调用(RPC、RMI)简单例子
2014/04/28 Python
使用Python编写简单的画图板程序的示例教程
2015/12/08 Python
Python按行读取文件的实现方法【小文件和大文件读取】
2016/09/19 Python
python引入导入自定义模块和外部文件的实例
2017/07/24 Python
用Python写王者荣耀刷金币脚本
2017/12/21 Python
用python wxpy管理微信公众号并利用微信获取自己的开源数据
2019/07/30 Python
Django中modelform组件实例用法总结
2020/02/10 Python
详解向scrapy中的spider传递参数的几种方法(2种)
2020/09/28 Python
New Balance比利时官方网站:购买鞋子和服装
2021/01/15 全球购物
护士岗位职责
2014/02/16 职场文书
创业培训计划书
2014/05/03 职场文书
优秀班干部主要事迹材料
2015/11/04 职场文书
Go 自定义package包设置与导入操作
2021/05/06 Golang
Golang二维数组的使用方式
2021/05/28 Golang
讨论nginx location 顺序问题
2022/05/30 Servers