tensorflow从ckpt和从.pb文件读取变量的值方式


Posted in Python onMay 26, 2020

最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形。

因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考。

(1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例)

from tensorflow.python import pywrap_tensorflow 
import tensorflow as tf
with tf.Graph().as_default(): 
 with tf.Session() as sess: 
 ckpt = tf.train.get_checkpoint_state('./model_ckpt') #保存ckpt文件的文件夹
 if ckpt and ckpt.model_checkpoint_path: 
 reader = pywrap_tensorflow.NewCheckpointReader('./model_ckpt/model.ckpt-999') #自己保存的ckpt文件名
 all_variables = reader.get_variable_to_shape_map() 
 w1 = reader.get_tensor("Variable_1") 
 print(w1.shape) 
 print(w1) 
 else: print('No checkpoint file found')

(2) 从保存的.pb文件读取变量的值(以读取保存的第一个权重为例)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
sess = tf.Session()
with gfile.FastGFile('Yourpb.pb', 'rb') as f: #自己保存的pb文件
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') 
 print(sess.run('Variable_1:0'))

补充知识:如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图

import tensorflow as tf
import os

CheckpointReader

tf.train.NewCheckpointReader是一个创建检查点读取器(CheckpointReader)对象的完美手段。 CheckpointReader中有几个非常有用的方法:

get_variable_to_shape_map() - 提供具有变量名称和形状的字典

debug_string() - 提供由检查点文件中所有变量组成的字符串

has_tensor(var_name) - 允许检查变量是否存在于检查点中

get_tensor(var_name) - 返回变量名称的张量

为了便于说明,我将定义一个函数来检查路径的有效性,并为您加载检查点读取器。

In [3]:

def load_reader(path):
 assert os.path.exists(path), "Provided incorrect path to the file. {} doesn't exist".format(path)
 return tf.train.NewCheckpointReader(path)

In [34]:

your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)

reader.debug_string()

用于返回包含以下内容的一个字符串:

variable name(变量名)

data type(数据类型)

tensor shape(张量类型)

它返回字符串的各元素间均用空格符' '分隔,你可以使用debug_string来创建一个变量名列表,如下所示:

In [53]:

all_var_descriptions = reader.debug_string().split()
var_names, var_shapes = all_var[::3], all_var[2::3]
print(var_names[:4])
print(var_shapes[:4])

输出:

['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']

但是,对于完成同样的任务,更好的方法是使用reader.get_variable_to_shape_map()

reader.get_variable_to_shape_map()

用于返回包含所有变量及其形状名称的字典,变量作为字典的Key,形状作为Value。

In [66]:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
reader.has_tensor(var_name)

返回bool值

这是一种方便的方法,允许您检查ckeckpoint中是否存在相关的变量。

In [51]:

names_that_exit = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for key in names_that_exit:
 print(key.decode()+':', names_that_exit[key])
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
reader.get_tensor(tensor_name)

返回包含检查点的张量值的NumPy数组

正常使用方法是先恢复一个张量,然后用恢复的张量初始化你自己的变量:

In [60]:

def recover_var(reader, var_name):
 recovered_var = 'var to be recovered'
 try:
  recovered_var = reader.get_tensor(var_name)
 except:
  assert reader.has_tensor(var_name),\
  "{} variable doesn't exist in the check point. Please check the variable name".format(var_name)
 return recovered_var

In [67]:

checkpoint_var = recover_var(reader, 'conv1/kernels')
print ("Recovered variable has the following shape: \n", checkpoint_var.shape)
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print ("New variable will be initialized with recovered values and the following shape: \n", new_var.get_shape())
Recovered variable has the following shape: 
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape: 
(3, 3, 3, 64)

以上这篇tensorflow从ckpt和从.pb文件读取变量的值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3访问并下载网页内容的方法
Jul 28 Python
分享6个隐藏的python功能
Dec 07 Python
python3.x上post发送json数据
Mar 04 Python
通过python爬虫赚钱的方法
Jan 29 Python
python异步存储数据详解
Mar 19 Python
python sort、sort_index方法代码实例
Mar 28 Python
在python tkinter中Canvas实现进度条显示的方法
Jun 14 Python
django 控制页面跳转的例子
Aug 06 Python
python线程信号量semaphore使用解析
Nov 30 Python
利用python绘制正态分布曲线
Jan 04 Python
Python中threading库实现线程锁与释放锁
May 17 Python
Python使用pandas导入csv文件内容的示例代码
Dec 24 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
You might like
PHP反转字符串函数strrev()函数的用法
2012/02/04 PHP
深入mysql_fetch_row()与mysql_fetch_array()的区别详解
2013/06/05 PHP
浅析THINKPHP的addAll支持的最大数据量
2015/02/03 PHP
基础的WordPress插件制作教程
2015/11/24 PHP
PHP入门教程之图像处理技巧分析
2016/09/11 PHP
PHP实现读取文件夹及批量重命名文件操作示例
2019/04/15 PHP
Prototype使用指南之selector.js
2007/01/10 Javascript
JS中setTimeout()的用法详解
2013/04/14 Javascript
flash遮住div问题的正确解决方法
2014/02/27 Javascript
jQuery中dom元素上绑定的事件详解
2015/04/24 Javascript
每天一篇javascript学习小结(RegExp对象)
2015/11/17 Javascript
js捕捉键盘事件和按键键值的方法
2016/10/10 Javascript
Bootstrap基本模板的使用和理解1
2016/12/14 Javascript
基于jQuery实现照片墙自动播放特效
2017/01/12 Javascript
Ionic + Angular.js实现验证码倒计时功能的方法
2017/06/12 Javascript
JS实现分页浏览横向图片(类轮播)实例代码
2017/11/06 Javascript
详细分析JS函数去抖和节流
2017/12/05 Javascript
jQuery基于Ajax实现读取XML数据功能示例
2018/05/31 jQuery
D3.js(v3)+react 实现带坐标与比例尺的散点图 (V3版本)
2019/05/09 Javascript
[50:05]VGJ.S vs OG 2018国际邀请赛淘汰赛BO3 第二场 8.22
2018/08/23 DOTA
Python实现并行抓取整站40万条房价数据(可更换抓取城市)
2016/12/14 Python
Python在不同目录下导入模块的实现方法
2017/10/27 Python
Python文本统计功能之西游记用字统计操作示例
2018/05/07 Python
Python3解释器知识点总结
2019/02/19 Python
django url到views参数传递的实例
2019/07/19 Python
python安装gdal的两种方法
2019/10/29 Python
python实现简单的购物程序代码实例
2020/03/03 Python
美国最大的高尔夫发球时间预订网站:TeeOff.com
2018/03/28 全球购物
Silk Therapeutics官网:清洁、抗衰老护肤品
2020/08/12 全球购物
说出一些常用的类,包,接口
2014/09/22 面试题
机械工程系毕业生求职信
2013/09/27 职场文书
外语专业毕业生自荐信
2014/04/14 职场文书
补充协议书
2015/01/28 职场文书
教师节简报
2015/07/20 职场文书
Python+Selenium实现读取网易邮箱验证码
2022/03/13 Python
python解析json数据
2022/04/29 Python