tensorflow从ckpt和从.pb文件读取变量的值方式


Posted in Python onMay 26, 2020

最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形。

因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考。

(1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例)

from tensorflow.python import pywrap_tensorflow 
import tensorflow as tf
with tf.Graph().as_default(): 
 with tf.Session() as sess: 
 ckpt = tf.train.get_checkpoint_state('./model_ckpt') #保存ckpt文件的文件夹
 if ckpt and ckpt.model_checkpoint_path: 
 reader = pywrap_tensorflow.NewCheckpointReader('./model_ckpt/model.ckpt-999') #自己保存的ckpt文件名
 all_variables = reader.get_variable_to_shape_map() 
 w1 = reader.get_tensor("Variable_1") 
 print(w1.shape) 
 print(w1) 
 else: print('No checkpoint file found')

(2) 从保存的.pb文件读取变量的值(以读取保存的第一个权重为例)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
sess = tf.Session()
with gfile.FastGFile('Yourpb.pb', 'rb') as f: #自己保存的pb文件
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') 
 print(sess.run('Variable_1:0'))

补充知识:如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图

import tensorflow as tf
import os

CheckpointReader

tf.train.NewCheckpointReader是一个创建检查点读取器(CheckpointReader)对象的完美手段。 CheckpointReader中有几个非常有用的方法:

get_variable_to_shape_map() - 提供具有变量名称和形状的字典

debug_string() - 提供由检查点文件中所有变量组成的字符串

has_tensor(var_name) - 允许检查变量是否存在于检查点中

get_tensor(var_name) - 返回变量名称的张量

为了便于说明,我将定义一个函数来检查路径的有效性,并为您加载检查点读取器。

In [3]:

def load_reader(path):
 assert os.path.exists(path), "Provided incorrect path to the file. {} doesn't exist".format(path)
 return tf.train.NewCheckpointReader(path)

In [34]:

your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)

reader.debug_string()

用于返回包含以下内容的一个字符串:

variable name(变量名)

data type(数据类型)

tensor shape(张量类型)

它返回字符串的各元素间均用空格符' '分隔,你可以使用debug_string来创建一个变量名列表,如下所示:

In [53]:

all_var_descriptions = reader.debug_string().split()
var_names, var_shapes = all_var[::3], all_var[2::3]
print(var_names[:4])
print(var_shapes[:4])

输出:

['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']

但是,对于完成同样的任务,更好的方法是使用reader.get_variable_to_shape_map()

reader.get_variable_to_shape_map()

用于返回包含所有变量及其形状名称的字典,变量作为字典的Key,形状作为Value。

In [66]:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
reader.has_tensor(var_name)

返回bool值

这是一种方便的方法,允许您检查ckeckpoint中是否存在相关的变量。

In [51]:

names_that_exit = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for key in names_that_exit:
 print(key.decode()+':', names_that_exit[key])
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
reader.get_tensor(tensor_name)

返回包含检查点的张量值的NumPy数组

正常使用方法是先恢复一个张量,然后用恢复的张量初始化你自己的变量:

In [60]:

def recover_var(reader, var_name):
 recovered_var = 'var to be recovered'
 try:
  recovered_var = reader.get_tensor(var_name)
 except:
  assert reader.has_tensor(var_name),\
  "{} variable doesn't exist in the check point. Please check the variable name".format(var_name)
 return recovered_var

In [67]:

checkpoint_var = recover_var(reader, 'conv1/kernels')
print ("Recovered variable has the following shape: \n", checkpoint_var.shape)
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print ("New variable will be initialized with recovered values and the following shape: \n", new_var.get_shape())
Recovered variable has the following shape: 
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape: 
(3, 3, 3, 64)

以上这篇tensorflow从ckpt和从.pb文件读取变量的值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python struct.unpack
Sep 06 Python
centos下更新Python版本的步骤
Feb 12 Python
Python中实现远程调用(RPC、RMI)简单例子
Apr 28 Python
Python os模块中的isfile()和isdir()函数均返回false问题解决方法
Feb 04 Python
python3 与python2 异常处理的区别与联系
Jun 19 Python
python中nan与inf转为特定数字方法示例
May 11 Python
python中字符串比较使用is、==和cmp()总结
Mar 18 Python
django输出html内容的实例
May 27 Python
Python实现bilibili时间长度查询的示例代码
Jan 14 Python
python 利用Pyinstaller打包Web项目
Oct 23 Python
Python根据字符串调用函数过程解析
Nov 05 Python
Python中Permission denied的解决方案
Apr 02 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
You might like
Fleaphp常见函数功能与用法示例
2016/11/15 PHP
PHP制作登录异常ip检测功能的实例代码
2016/11/16 PHP
Laravel框架创建路由的方法详解
2019/09/04 PHP
JavaScript具有类似Lambda表达式编程能力的代码(改进版)
2010/09/14 Javascript
Wordpress ThickBox 点击图片显示下一张图的修改方法
2010/12/11 Javascript
javascript禁用Tab键脚本实例
2013/11/22 Javascript
jquery获取URL中参数解决中文乱码问题的两种方法
2013/12/18 Javascript
jquery实现下拉菜单的二级联动利用json对象从DB取值显示联动
2014/03/27 Javascript
jQuery实现滚动鼠标放大缩小图片的方法(附demo源码下载)
2016/03/05 Javascript
拖动时防止选中
2017/02/03 Javascript
JS中使用正则表达式g模式和非g模式的区别
2017/04/01 Javascript
.net MVC+Bootstrap下使用localResizeIMG上传图片
2017/04/21 Javascript
jQuery Tree Multiselect使用详解
2017/05/02 jQuery
js如何编写简单的ajax方法库
2017/08/02 Javascript
JS生成随机打乱数组的方法示例
2017/12/23 Javascript
[原创]jQuery实现合并/追加数组并去除重复项的方法
2018/04/11 jQuery
vue 父组件给子组件传值子组件给父组件传值的实例代码
2019/04/15 Javascript
详解vuejs2.0 select 动态绑定下拉框支持多选
2019/04/25 Javascript
[01:03:50]DOTA2-DPC中国联赛 正赛 CDEC vs DLG BO3 第二场 2月7日
2021/03/11 DOTA
python 文件和路径操作函数小结
2009/11/23 Python
Python之csv文件从MySQL数据库导入导出的方法
2018/06/21 Python
解决python ogr shp字段写入中文乱码的问题
2018/12/31 Python
使用pycharm在本地开发并实时同步到服务器
2019/08/02 Python
django框架用户权限中的session缓存到redis中的方法
2019/08/06 Python
利用Python脚本批量生成SQL语句
2020/03/04 Python
深入了解canvas在移动端绘制模糊的问题解决
2019/04/30 HTML / CSS
意大利时尚奢侈品店:D’Aniello Boutique
2021/01/19 全球购物
巴西最大的运动品牌:Olympikus
2020/07/14 全球购物
全陪导游欢迎词
2014/01/17 职场文书
房屋买卖委托书格式范本格式
2014/10/13 职场文书
2014年幼儿园保育工作总结
2014/12/02 职场文书
2015年全国爱眼日活动方案
2015/05/05 职场文书
先进党支部事迹材料2016
2016/02/26 职场文书
幼儿园小班教学反思
2016/03/03 职场文书
导游词之上饶龟峰
2019/10/25 职场文书
假如给我三天光明:舟逆水而行,人遇挫而达 
2019/10/29 职场文书