tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python让图片按照exif信息里的创建时间进行排序的方法
Mar 16 Python
Python闭包实现计数器的方法
May 05 Python
Python的地形三维可视化Matplotlib和gdal使用实例
Dec 09 Python
Python简单实现的代理服务器端口映射功能示例
Apr 08 Python
CentOS7下python3.7.0安装教程
Jul 30 Python
pyspark.sql.DataFrame与pandas.DataFrame之间的相互转换实例
Aug 02 Python
Python对切片命名的实现方法
Oct 16 Python
python模拟菜刀反弹shell绕过限制【推荐】
Jun 25 Python
Python用Try语句捕获异常的实例方法
Jun 26 Python
python变量的存储原理详解
Jul 10 Python
python求质数列表的例子
Nov 24 Python
python 进制转换 int、bin、oct、hex的原理
Jan 13 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
php下批量挂马和批量清马代码
2011/02/27 PHP
PHP文件注释标记及规范小结
2012/04/01 PHP
PHP中的生成XML文件的4种方法分享
2012/10/06 PHP
php实现修改新闻时删除图片的方法
2015/05/12 PHP
PHP截取IE浏览器并缩小原图的方法
2016/03/04 PHP
PHP有序表查找之二分查找(折半查找)算法示例
2018/02/09 PHP
JS制作手机端自适应缩放显示
2015/06/11 Javascript
jquery滚动条插件slimScroll使用方法
2017/02/09 Javascript
微信小程序之GET请求的实例详解
2017/09/29 Javascript
vue2.0 中使用transition实现动画效果使用心得
2018/08/13 Javascript
js实现百度登录窗口拖拽效果
2020/03/19 Javascript
vue在响应头response中获取自定义headers操作
2020/07/24 Javascript
Vue实现图书管理小案例
2020/12/03 Vue.js
[38:21]2014 DOTA2国际邀请赛中国区预选赛5.21 TongFu VS LGD-CDEC
2014/05/22 DOTA
[00:10]神之谴戒
2019/03/06 DOTA
[42:06]2019国际邀请赛全明星赛 8.23
2019/09/05 DOTA
Django的session中对于用户验证的支持
2015/07/23 Python
Python实现身份证号码解析
2015/09/01 Python
Python字典的核心底层原理讲解
2019/01/24 Python
python格式化输出保留2位小数的实现方法
2019/07/02 Python
python点击鼠标获取坐标(Graphics)
2019/08/10 Python
Python脚本去除文件的只读性操作
2020/03/05 Python
Python中zipfile压缩文件模块的基本使用教程
2020/06/14 Python
增大python字体的方法步骤
2020/07/05 Python
使用CSS3 制作一个material-design 风格登录界面实例
2016/12/12 HTML / CSS
html5 Canvas画图教程(9)—canvas中画出矩形和圆形
2013/01/09 HTML / CSS
英国信箱在线鲜花速递公司:Bloom & Wild
2019/03/10 全球购物
英国外籍人士的在线超市:British Corner Shop
2019/06/03 全球购物
护士自荐信怎么写
2013/10/18 职场文书
求职信写作要突出重点
2014/01/01 职场文书
报关专员求职信范文
2014/02/22 职场文书
2014庆六一活动方案
2014/03/02 职场文书
校庆标语集锦
2014/06/25 职场文书
村级四风对照检查材料
2014/08/24 职场文书
详解Nginx 被动检查服务器的存活状态
2021/10/16 Servers
springboot新建项目pom.xml文件第一行报错的解决
2022/01/18 Java/Android