tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解析发往本机的数据包示例 (解析数据包)
Jan 16 Python
跟老齐学Python之折腾一下目录
Oct 24 Python
详解python中的装饰器
Jul 10 Python
Python多进程与服务器并发原理及用法实例分析
Aug 21 Python
导入tensorflow时报错:cannot import name 'abs'的解决
Oct 10 Python
python栈的基本定义与使用方法示例【初始化、赋值、入栈、出栈等】
Oct 24 Python
python脚本后台执行方式
Dec 21 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 Python
Django ModelForm组件原理及用法详解
Oct 12 Python
python基于opencv实现人脸识别
Jan 04 Python
深入理解python多线程编程
Apr 18 Python
Python 中数组和数字相乘时的注意事项说明
May 10 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
php学习之 循环结构实现代码
2011/06/09 PHP
php的declare控制符和ticks教程(附示例)
2014/03/21 PHP
PHP中date与gmdate的区别及默认时区设置
2014/05/12 PHP
PHP正则表达式匹配替换与分割功能实例浅析
2017/02/04 PHP
JavaScript 编程引入命名空间的方法与代码
2007/08/13 Javascript
javascript OFFICE控件测试代码
2009/12/08 Javascript
封装了一个js图片轮换效果的函数
2011/09/28 Javascript
提交表单时执行func方法实现代码
2013/03/17 Javascript
JS实现弹出居中的模式窗口示例
2016/06/20 Javascript
js编写一个简单的产品放大效果代码
2016/06/27 Javascript
layui文件上传实现代码
2017/05/20 Javascript
动态创建Angular组件实现popup弹窗功能
2017/09/15 Javascript
第一个Vue插件从封装到发布
2017/11/22 Javascript
NodeJs入门教程之定时器和队列
2019/03/08 NodeJs
微信小程序实现元素渐入渐出动画效果封装方法
2019/05/18 Javascript
javascript实现移动端上传图片功能
2020/08/18 Javascript
js实现移动端轮播图滑动切换
2020/12/21 Javascript
[36:16]完美世界DOTA2联赛PWL S3 access vs Rebirth 第一场 12.19
2020/12/24 DOTA
使用Python的Scrapy框架十分钟爬取美女图
2016/12/26 Python
运行django项目指定IP和端口的方法
2018/05/14 Python
Python将文本去空格并保存到txt文件中的实例
2018/07/24 Python
windows10下安装TensorFlow Object Detection API的步骤
2019/06/13 Python
Python 微信爬虫完整实例【单线程与多线程】
2019/07/06 Python
python实现银行账户系统
2021/02/22 Python
英国家庭珠宝商:T. H. Baker
2018/02/08 全球购物
BIBLOO波兰:捷克的一家在线服装店
2018/03/09 全球购物
亚洲领先的旅游体验市场:Voyagin
2019/11/23 全球购物
Ariat英国官网:为世界顶级马术运动员制造最优质的鞋类和服装
2020/02/14 全球购物
荷兰最大的鞋子、服装和运动折扣店:Bristol
2021/01/07 全球购物
PHP如何设置和取得Cookie值
2015/06/30 面试题
大学生创业感言
2014/01/25 职场文书
《藏戏》教学反思
2014/02/11 职场文书
基督教追悼会答谢词
2015/09/29 职场文书
Python离线安装openpyxl模块的步骤
2021/03/30 Python
Nginx中break与last的区别详析
2021/03/31 Servers
Redis中有序集合的内部实现方式的详细介绍
2022/03/16 Redis