tensorflow从ckpt和从.pb文件读取变量的值方式


Posted in Python onMay 26, 2020

最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形。

因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考。

(1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例)

from tensorflow.python import pywrap_tensorflow 
import tensorflow as tf
with tf.Graph().as_default(): 
 with tf.Session() as sess: 
 ckpt = tf.train.get_checkpoint_state('./model_ckpt') #保存ckpt文件的文件夹
 if ckpt and ckpt.model_checkpoint_path: 
 reader = pywrap_tensorflow.NewCheckpointReader('./model_ckpt/model.ckpt-999') #自己保存的ckpt文件名
 all_variables = reader.get_variable_to_shape_map() 
 w1 = reader.get_tensor("Variable_1") 
 print(w1.shape) 
 print(w1) 
 else: print('No checkpoint file found')

(2) 从保存的.pb文件读取变量的值(以读取保存的第一个权重为例)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
sess = tf.Session()
with gfile.FastGFile('Yourpb.pb', 'rb') as f: #自己保存的pb文件
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') 
 print(sess.run('Variable_1:0'))

补充知识:如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图

import tensorflow as tf
import os

CheckpointReader

tf.train.NewCheckpointReader是一个创建检查点读取器(CheckpointReader)对象的完美手段。 CheckpointReader中有几个非常有用的方法:

get_variable_to_shape_map() - 提供具有变量名称和形状的字典

debug_string() - 提供由检查点文件中所有变量组成的字符串

has_tensor(var_name) - 允许检查变量是否存在于检查点中

get_tensor(var_name) - 返回变量名称的张量

为了便于说明,我将定义一个函数来检查路径的有效性,并为您加载检查点读取器。

In [3]:

def load_reader(path):
 assert os.path.exists(path), "Provided incorrect path to the file. {} doesn't exist".format(path)
 return tf.train.NewCheckpointReader(path)

In [34]:

your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)

reader.debug_string()

用于返回包含以下内容的一个字符串:

variable name(变量名)

data type(数据类型)

tensor shape(张量类型)

它返回字符串的各元素间均用空格符' '分隔,你可以使用debug_string来创建一个变量名列表,如下所示:

In [53]:

all_var_descriptions = reader.debug_string().split()
var_names, var_shapes = all_var[::3], all_var[2::3]
print(var_names[:4])
print(var_shapes[:4])

输出:

['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']

但是,对于完成同样的任务,更好的方法是使用reader.get_variable_to_shape_map()

reader.get_variable_to_shape_map()

用于返回包含所有变量及其形状名称的字典,变量作为字典的Key,形状作为Value。

In [66]:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
reader.has_tensor(var_name)

返回bool值

这是一种方便的方法,允许您检查ckeckpoint中是否存在相关的变量。

In [51]:

names_that_exit = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for key in names_that_exit:
 print(key.decode()+':', names_that_exit[key])
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
reader.get_tensor(tensor_name)

返回包含检查点的张量值的NumPy数组

正常使用方法是先恢复一个张量,然后用恢复的张量初始化你自己的变量:

In [60]:

def recover_var(reader, var_name):
 recovered_var = 'var to be recovered'
 try:
  recovered_var = reader.get_tensor(var_name)
 except:
  assert reader.has_tensor(var_name),\
  "{} variable doesn't exist in the check point. Please check the variable name".format(var_name)
 return recovered_var

In [67]:

checkpoint_var = recover_var(reader, 'conv1/kernels')
print ("Recovered variable has the following shape: \n", checkpoint_var.shape)
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print ("New variable will be initialized with recovered values and the following shape: \n", new_var.get_shape())
Recovered variable has the following shape: 
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape: 
(3, 3, 3, 64)

以上这篇tensorflow从ckpt和从.pb文件读取变量的值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python编写脚本使IE实现代理上网的教程
Apr 23 Python
在Windows服务器下用Apache和mod_wsgi配置Python应用的教程
May 06 Python
详解Python的collections模块中的deque双端队列结构
Jul 07 Python
Python和C/C++交互的几种方法总结
May 11 Python
Python字符编码与函数的基本使用方法
Sep 30 Python
windows下添加Python环境变量的方法汇总
May 14 Python
深入浅析Python的类
Jun 22 Python
python银行系统实现源码
Oct 25 Python
wxPython窗体拆分布局基础组件
Nov 19 Python
python基于opencv检测程序运行效率
Dec 28 Python
pytorch之Resize()函数具体使用详解
Feb 27 Python
python百行代码自制电脑端网速悬浮窗的实现
May 12 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
You might like
『PHP』PHP截断函数mb_substr()使用介绍
2013/04/22 PHP
php使用正则过滤js脚本代码实例
2014/05/10 PHP
php session的锁和并发
2016/01/22 PHP
thinkphp中的url跳转用法分析
2016/07/12 PHP
浅谈php中urlencode与rawurlencode的区别
2016/09/05 PHP
Laravel路由研究之domain解决多域名问题的方法示例
2019/04/04 PHP
javascript object array方法使用详解
2012/12/03 Javascript
jQuery使用数组编写图片无缝向左滚动
2012/12/11 Javascript
javascript中onclick(this)用法介绍
2013/04/19 Javascript
JavaScript实现同一页面内两个表单互相传值的方法
2015/08/12 Javascript
JavaScript如何获取数组最大值和最小值
2015/11/18 Javascript
JavaScript电子时钟倒计时
2016/01/09 Javascript
xmlplus组件设计系列之路由(ViewStack)(7)
2017/05/02 Javascript
vue弹窗消息组件的使用方法
2020/09/24 Javascript
浅谈JavaScript 代码整洁之道
2018/10/23 Javascript
Vue唯一可以更改vuex实例中state数据状态的属性对象Mutation的讲解
2019/01/18 Javascript
JS使用正则表达式实现常用的表单验证功能分析
2020/04/30 Javascript
Django中几种重定向方法
2015/04/28 Python
详解Python中映射类型(字典)操作符的概念和使用
2015/08/19 Python
Python正则表达式非贪婪、多行匹配功能示例
2017/08/08 Python
解决出现Incorrect integer value: '' for column 'id' at row 1的问题
2017/10/29 Python
Python内置函数——__import__ 的使用方法
2017/11/24 Python
Python pandas如何向excel添加数据
2020/05/22 Python
Python使用itcaht库实现微信自动收发消息功能
2020/07/13 Python
Python爬虫设置ip代理过程解析
2020/07/20 Python
Selenium alert 弹窗处理的示例代码
2020/08/06 Python
使用HTML5拍照示例代码
2013/08/06 HTML / CSS
缅甸网上购物:Shop.com.mm
2017/12/05 全球购物
个人实用简单的自我评价
2013/10/19 职场文书
如何写一份好的自荐信
2014/01/02 职场文书
初中生思想道德自我评价
2015/03/09 职场文书
2015年中秋节主持词
2015/07/30 职场文书
2015年小学语文教师工作总结
2015/10/23 职场文书
2016年习总书记讲话学习心得体会
2016/01/20 职场文书
python实现的人脸识别打卡系统
2021/05/08 Python
Python内置包对JSON文件数据进行编码和解码
2022/04/12 Python