tensorflow从ckpt和从.pb文件读取变量的值方式


Posted in Python onMay 26, 2020

最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形。

因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考。

(1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例)

from tensorflow.python import pywrap_tensorflow 
import tensorflow as tf
with tf.Graph().as_default(): 
 with tf.Session() as sess: 
 ckpt = tf.train.get_checkpoint_state('./model_ckpt') #保存ckpt文件的文件夹
 if ckpt and ckpt.model_checkpoint_path: 
 reader = pywrap_tensorflow.NewCheckpointReader('./model_ckpt/model.ckpt-999') #自己保存的ckpt文件名
 all_variables = reader.get_variable_to_shape_map() 
 w1 = reader.get_tensor("Variable_1") 
 print(w1.shape) 
 print(w1) 
 else: print('No checkpoint file found')

(2) 从保存的.pb文件读取变量的值(以读取保存的第一个权重为例)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
sess = tf.Session()
with gfile.FastGFile('Yourpb.pb', 'rb') as f: #自己保存的pb文件
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') 
 print(sess.run('Variable_1:0'))

补充知识:如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图

import tensorflow as tf
import os

CheckpointReader

tf.train.NewCheckpointReader是一个创建检查点读取器(CheckpointReader)对象的完美手段。 CheckpointReader中有几个非常有用的方法:

get_variable_to_shape_map() - 提供具有变量名称和形状的字典

debug_string() - 提供由检查点文件中所有变量组成的字符串

has_tensor(var_name) - 允许检查变量是否存在于检查点中

get_tensor(var_name) - 返回变量名称的张量

为了便于说明,我将定义一个函数来检查路径的有效性,并为您加载检查点读取器。

In [3]:

def load_reader(path):
 assert os.path.exists(path), "Provided incorrect path to the file. {} doesn't exist".format(path)
 return tf.train.NewCheckpointReader(path)

In [34]:

your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)

reader.debug_string()

用于返回包含以下内容的一个字符串:

variable name(变量名)

data type(数据类型)

tensor shape(张量类型)

它返回字符串的各元素间均用空格符' '分隔,你可以使用debug_string来创建一个变量名列表,如下所示:

In [53]:

all_var_descriptions = reader.debug_string().split()
var_names, var_shapes = all_var[::3], all_var[2::3]
print(var_names[:4])
print(var_shapes[:4])

输出:

['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']

但是,对于完成同样的任务,更好的方法是使用reader.get_variable_to_shape_map()

reader.get_variable_to_shape_map()

用于返回包含所有变量及其形状名称的字典,变量作为字典的Key,形状作为Value。

In [66]:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
reader.has_tensor(var_name)

返回bool值

这是一种方便的方法,允许您检查ckeckpoint中是否存在相关的变量。

In [51]:

names_that_exit = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for key in names_that_exit:
 print(key.decode()+':', names_that_exit[key])
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
reader.get_tensor(tensor_name)

返回包含检查点的张量值的NumPy数组

正常使用方法是先恢复一个张量,然后用恢复的张量初始化你自己的变量:

In [60]:

def recover_var(reader, var_name):
 recovered_var = 'var to be recovered'
 try:
  recovered_var = reader.get_tensor(var_name)
 except:
  assert reader.has_tensor(var_name),\
  "{} variable doesn't exist in the check point. Please check the variable name".format(var_name)
 return recovered_var

In [67]:

checkpoint_var = recover_var(reader, 'conv1/kernels')
print ("Recovered variable has the following shape: \n", checkpoint_var.shape)
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print ("New variable will be initialized with recovered values and the following shape: \n", new_var.get_shape())
Recovered variable has the following shape: 
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape: 
(3, 3, 3, 64)

以上这篇tensorflow从ckpt和从.pb文件读取变量的值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过colorama模块在控制台输出彩色文字的方法
Mar 19 Python
Python如何实现MySQL实例初始化详解
Nov 06 Python
python分治法求二维数组局部峰值方法
Apr 03 Python
python smtplib模块自动收发邮件功能(一)
May 22 Python
Python实现拷贝/删除文件夹的方法详解
Aug 29 Python
用sqlalchemy构建Django连接池的实例
Aug 29 Python
Pytorch在NLP中的简单应用详解
Jan 08 Python
在jupyter notebook 添加 conda 环境的操作详解
Apr 10 Python
在tensorflow以及keras安装目录查询操作(windows下)
Jun 19 Python
Python 如何对文件目录操作
Jul 10 Python
基于Pytorch版yolov5的滑块验证码破解思路详解
Feb 25 Python
基于Python 函数和方法的区别说明
Mar 24 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
You might like
什么是短波收听SWL
2021/03/01 无线电
PhpMyAdmin中无法导入sql文件的解决办法
2010/01/08 PHP
CodeIgniter框架钩子机制实现方法【hooks类】
2018/08/21 PHP
PHP7 foreach() 函数修改
2021/03/09 PHP
form表单中去掉默认的enter键提交并绑定js方法实现代码
2013/04/01 Javascript
jQuery移动web开发之页面跳转和加载外部页面的实现
2015/12/04 Javascript
怎么限制input的text里输入的值只能是数字(正则、js)
2016/05/16 Javascript
JavaScript性能优化总结之加载与执行
2016/08/11 Javascript
jQuery仿写百度百科的目录树
2017/01/03 Javascript
jQuery实现联动下拉列表查询框
2017/01/04 Javascript
详解让sublime text3支持Vue语法高亮显示的示例
2017/09/29 Javascript
微信小程序 input输入及动态设置按钮的实现
2017/10/27 Javascript
layui.js实现的表单验证功能示例
2017/11/15 Javascript
vue proxy 的优势与使用场景实现
2020/06/15 Javascript
JavaScript缺少insertAfter解决方案
2020/07/03 Javascript
[42:20]Winstrike vs VGJ.S 2018国际邀请赛淘汰赛BO3 第二场 8.23
2018/08/24 DOTA
[49:58]完美世界DOTA2联赛PWL S3 Magma vs DLG 第一场 12.18
2020/12/19 DOTA
[01:23:24]DOTA2-DPC中国联赛 正赛 PSG.LGD vs Elephant BO3 第三场 2月7日
2021/03/11 DOTA
Python Queue模块详解
2014/11/30 Python
python去重,一个由dict组成的list的去重示例
2019/01/21 Python
Python实现去除列表中重复元素的方法总结【7种方法】
2019/02/16 Python
Numpy对数组的操作:创建、变形(升降维等)、计算、取值、复制、分割、合并
2019/08/28 Python
基于pycharm实现批量修改变量名
2020/06/02 Python
解决python图像处理图像赋值后变为白色的问题
2020/06/04 Python
python 利用Pyinstaller打包Web项目
2020/10/23 Python
解决HTML5手机端页面缩放的问题
2017/10/27 HTML / CSS
html5通过postMessage进行跨域通信的方法
2017/12/04 HTML / CSS
Michael Kors加拿大官网:购买设计师手袋、手表、鞋子、服装等
2019/03/16 全球购物
js正则匹配markdown里的图片标签的实现
2021/03/24 Javascript
历史学专业大学生找工作的自我评价
2013/10/16 职场文书
党员大会主持词
2014/04/02 职场文书
商场租赁意向书
2014/07/30 职场文书
企业法人授权委托书范本
2014/09/23 职场文书
工厂门卫岗位职责
2015/04/13 职场文书
2015年教务主任工作总结
2015/07/22 职场文书
python中redis包操作数据库的教程
2022/04/19 Python