tensorflow从ckpt和从.pb文件读取变量的值方式


Posted in Python onMay 26, 2020

最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形。

因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考。

(1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例)

from tensorflow.python import pywrap_tensorflow 
import tensorflow as tf
with tf.Graph().as_default(): 
 with tf.Session() as sess: 
 ckpt = tf.train.get_checkpoint_state('./model_ckpt') #保存ckpt文件的文件夹
 if ckpt and ckpt.model_checkpoint_path: 
 reader = pywrap_tensorflow.NewCheckpointReader('./model_ckpt/model.ckpt-999') #自己保存的ckpt文件名
 all_variables = reader.get_variable_to_shape_map() 
 w1 = reader.get_tensor("Variable_1") 
 print(w1.shape) 
 print(w1) 
 else: print('No checkpoint file found')

(2) 从保存的.pb文件读取变量的值(以读取保存的第一个权重为例)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
sess = tf.Session()
with gfile.FastGFile('Yourpb.pb', 'rb') as f: #自己保存的pb文件
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') 
 print(sess.run('Variable_1:0'))

补充知识:如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图

import tensorflow as tf
import os

CheckpointReader

tf.train.NewCheckpointReader是一个创建检查点读取器(CheckpointReader)对象的完美手段。 CheckpointReader中有几个非常有用的方法:

get_variable_to_shape_map() - 提供具有变量名称和形状的字典

debug_string() - 提供由检查点文件中所有变量组成的字符串

has_tensor(var_name) - 允许检查变量是否存在于检查点中

get_tensor(var_name) - 返回变量名称的张量

为了便于说明,我将定义一个函数来检查路径的有效性,并为您加载检查点读取器。

In [3]:

def load_reader(path):
 assert os.path.exists(path), "Provided incorrect path to the file. {} doesn't exist".format(path)
 return tf.train.NewCheckpointReader(path)

In [34]:

your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)

reader.debug_string()

用于返回包含以下内容的一个字符串:

variable name(变量名)

data type(数据类型)

tensor shape(张量类型)

它返回字符串的各元素间均用空格符' '分隔,你可以使用debug_string来创建一个变量名列表,如下所示:

In [53]:

all_var_descriptions = reader.debug_string().split()
var_names, var_shapes = all_var[::3], all_var[2::3]
print(var_names[:4])
print(var_shapes[:4])

输出:

['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']

但是,对于完成同样的任务,更好的方法是使用reader.get_variable_to_shape_map()

reader.get_variable_to_shape_map()

用于返回包含所有变量及其形状名称的字典,变量作为字典的Key,形状作为Value。

In [66]:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
reader.has_tensor(var_name)

返回bool值

这是一种方便的方法,允许您检查ckeckpoint中是否存在相关的变量。

In [51]:

names_that_exit = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for key in names_that_exit:
 print(key.decode()+':', names_that_exit[key])
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
reader.get_tensor(tensor_name)

返回包含检查点的张量值的NumPy数组

正常使用方法是先恢复一个张量,然后用恢复的张量初始化你自己的变量:

In [60]:

def recover_var(reader, var_name):
 recovered_var = 'var to be recovered'
 try:
  recovered_var = reader.get_tensor(var_name)
 except:
  assert reader.has_tensor(var_name),\
  "{} variable doesn't exist in the check point. Please check the variable name".format(var_name)
 return recovered_var

In [67]:

checkpoint_var = recover_var(reader, 'conv1/kernels')
print ("Recovered variable has the following shape: \n", checkpoint_var.shape)
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print ("New variable will be initialized with recovered values and the following shape: \n", new_var.get_shape())
Recovered variable has the following shape: 
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape: 
(3, 3, 3, 64)

以上这篇tensorflow从ckpt和从.pb文件读取变量的值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 不关闭控制台的实现方法
Oct 23 Python
Python定时执行之Timer用法示例
May 27 Python
浅谈Python的垃圾回收机制
Dec 17 Python
分享几道你可能遇到的python面试题
Jul 24 Python
一文总结学习Python的14张思维导图
Oct 17 Python
python用BeautifulSoup库简单爬虫实例分析
Jul 30 Python
python基于opencv检测程序运行效率
Dec 28 Python
Pandas时间序列:时期(period)及其算术运算详解
Feb 25 Python
python 实现人和电脑猜拳的示例代码
Mar 02 Python
python如何提升爬虫效率
Sep 27 Python
python读取pdf格式文档的实现代码
Apr 01 Python
Python语言中的数据类型-序列
Feb 24 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
You might like
基于php无限分类的深入理解
2013/06/02 PHP
浅析虚拟主机服务器php fsockopen函数被禁用的解决办法
2013/08/07 PHP
PHP查看SSL证书信息的方法
2016/09/22 PHP
PHP实现的解汉诺塔问题算法示例
2018/08/06 PHP
jquery last-child 列表最后一项的样式
2010/01/22 Javascript
Script的加载方法小结
2011/01/12 Javascript
判断是否安装flash player及当前版本的JS代码
2013/08/08 Javascript
JQuery遍历json数组的3种方法
2014/11/08 Javascript
jQuery+jRange实现滑动选取数值范围特效
2015/03/14 Javascript
JS实现网页右侧带动画效果的伸缩窗口代码
2015/10/29 Javascript
JavaScript 事件对内存和性能的影响
2017/01/22 Javascript
微信小程序 JS动态修改样式的实现代码
2017/02/10 Javascript
Vue如何实现组件的源码解析
2017/06/08 Javascript
浅谈JS封闭函数、闭包、内置对象
2017/07/18 Javascript
Vue2.0基于vue-cli+webpack父子组件通信(实例讲解)
2017/09/14 Javascript
基于JavaScript实现幸运抽奖页面
2020/07/05 Javascript
vue.extend与vue.component的区别和联系
2018/09/19 Javascript
bootstrap中的导航条实例代码详解
2019/05/20 Javascript
[48:35]2018DOTA2亚洲邀请赛 4.1 小组赛 A组加赛 TNC vs Optic
2018/04/03 DOTA
Python易忽视知识点小结
2015/05/25 Python
Python脚本简单实现打开默认浏览器登录人人和打开QQ的方法
2016/04/12 Python
Python数据分析之真实IP请求Pandas详解
2016/11/18 Python
Windows下Anaconda的安装和简单使用方法
2018/01/04 Python
python将.ppm格式图片转换成.jpg格式文件的方法
2018/10/27 Python
python矩阵的转置和逆转实例
2018/12/12 Python
python二维码操作:对QRCode和MyQR入门详解
2019/06/24 Python
基于python实现从尾到头打印链表
2019/11/02 Python
如何将Pycharm中调整字体大小的方式设置为"ctrl+鼠标滚轮上下滑"
2020/11/17 Python
Python运算符+与+=的方法实例
2021/02/18 Python
Python 求向量的余弦值操作
2021/03/04 Python
德国香水、化妆品和护理产品网上商店:Parfumdreams
2018/09/26 全球购物
新员工试用期工作总结2015
2015/05/28 职场文书
老干部座谈会主持词
2015/07/03 职场文书
mysql5.7使用binlog 恢复数据的方法
2021/06/03 MySQL
详解Java ES多节点任务的高效分发与收集实现
2021/06/30 Java/Android
Go gorilla securecookie库的安装使用详解
2022/08/14 Golang