tensorflow从ckpt和从.pb文件读取变量的值方式


Posted in Python onMay 26, 2020

最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形。

因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考。

(1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例)

from tensorflow.python import pywrap_tensorflow 
import tensorflow as tf
with tf.Graph().as_default(): 
 with tf.Session() as sess: 
 ckpt = tf.train.get_checkpoint_state('./model_ckpt') #保存ckpt文件的文件夹
 if ckpt and ckpt.model_checkpoint_path: 
 reader = pywrap_tensorflow.NewCheckpointReader('./model_ckpt/model.ckpt-999') #自己保存的ckpt文件名
 all_variables = reader.get_variable_to_shape_map() 
 w1 = reader.get_tensor("Variable_1") 
 print(w1.shape) 
 print(w1) 
 else: print('No checkpoint file found')

(2) 从保存的.pb文件读取变量的值(以读取保存的第一个权重为例)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
sess = tf.Session()
with gfile.FastGFile('Yourpb.pb', 'rb') as f: #自己保存的pb文件
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') 
 print(sess.run('Variable_1:0'))

补充知识:如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图

import tensorflow as tf
import os

CheckpointReader

tf.train.NewCheckpointReader是一个创建检查点读取器(CheckpointReader)对象的完美手段。 CheckpointReader中有几个非常有用的方法:

get_variable_to_shape_map() - 提供具有变量名称和形状的字典

debug_string() - 提供由检查点文件中所有变量组成的字符串

has_tensor(var_name) - 允许检查变量是否存在于检查点中

get_tensor(var_name) - 返回变量名称的张量

为了便于说明,我将定义一个函数来检查路径的有效性,并为您加载检查点读取器。

In [3]:

def load_reader(path):
 assert os.path.exists(path), "Provided incorrect path to the file. {} doesn't exist".format(path)
 return tf.train.NewCheckpointReader(path)

In [34]:

your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)

reader.debug_string()

用于返回包含以下内容的一个字符串:

variable name(变量名)

data type(数据类型)

tensor shape(张量类型)

它返回字符串的各元素间均用空格符' '分隔,你可以使用debug_string来创建一个变量名列表,如下所示:

In [53]:

all_var_descriptions = reader.debug_string().split()
var_names, var_shapes = all_var[::3], all_var[2::3]
print(var_names[:4])
print(var_shapes[:4])

输出:

['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']

但是,对于完成同样的任务,更好的方法是使用reader.get_variable_to_shape_map()

reader.get_variable_to_shape_map()

用于返回包含所有变量及其形状名称的字典,变量作为字典的Key,形状作为Value。

In [66]:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
reader.has_tensor(var_name)

返回bool值

这是一种方便的方法,允许您检查ckeckpoint中是否存在相关的变量。

In [51]:

names_that_exit = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for key in names_that_exit:
 print(key.decode()+':', names_that_exit[key])
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
reader.get_tensor(tensor_name)

返回包含检查点的张量值的NumPy数组

正常使用方法是先恢复一个张量,然后用恢复的张量初始化你自己的变量:

In [60]:

def recover_var(reader, var_name):
 recovered_var = 'var to be recovered'
 try:
  recovered_var = reader.get_tensor(var_name)
 except:
  assert reader.has_tensor(var_name),\
  "{} variable doesn't exist in the check point. Please check the variable name".format(var_name)
 return recovered_var

In [67]:

checkpoint_var = recover_var(reader, 'conv1/kernels')
print ("Recovered variable has the following shape: \n", checkpoint_var.shape)
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print ("New variable will be initialized with recovered values and the following shape: \n", new_var.get_shape())
Recovered variable has the following shape: 
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape: 
(3, 3, 3, 64)

以上这篇tensorflow从ckpt和从.pb文件读取变量的值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python基于lxml模块解析html获取页面内所有叶子节点xpath路径功能示例
May 16 Python
Python基于pandas实现json格式转换成dataframe的方法
Jun 22 Python
python实现n个数中选出m个数的方法
Nov 13 Python
python利用Tesseract识别验证码的方法示例
Jan 21 Python
详解Python给照片换底色(蓝底换红底)
Mar 22 Python
Python浮点数四舍五入问题的分析与解决方法
Nov 19 Python
python 实现简单的FTP程序
Dec 27 Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 Python
Python调用接口合并Excel表代码实例
Mar 31 Python
基于python实现地址和经纬度转换
May 19 Python
关于tf.matmul() 和tf.multiply() 的区别说明
Jun 18 Python
python爬虫构建代理ip池抓取数据库的示例代码
Sep 22 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
You might like
根德Grundig S400/S500/S700电路分析
2021/03/02 无线电
php中用加号与用array_merge合并数组的区别深入分析
2013/06/03 PHP
php class中public,private,protected的区别以及实例分析
2013/06/18 PHP
codeigniter使用技巧批量插入数据实例方法分享
2013/12/31 PHP
PHP封装的多文件上传类实例与用法详解
2017/02/07 PHP
解决Laravel5.5下的toArray问题
2019/10/15 PHP
Centos7安装swoole扩展操作示例
2020/03/26 PHP
关于JavaScript中原型继承中的一点思考
2012/07/25 Javascript
Js判断参数(String,Array,Object)是否为undefined或者值为空
2013/11/04 Javascript
JavaScript实现点击按钮复制指定区域文本(推荐)
2016/11/25 Javascript
微信小程序 ecshop地址三级联动实现实例代码
2017/02/28 Javascript
详解Vue使用命令行搭建单页面应用
2017/05/24 Javascript
jquery加载单文件vue组件的方法
2017/06/20 jQuery
Vue.js实现分页查询功能
2020/11/15 Javascript
JavaScript实现浅拷贝与深拷贝的方法分析
2018/07/05 Javascript
JQuery获得内容和属性方法解析
2020/05/30 jQuery
修改Vue打包后的默认文件名操作
2020/08/12 Javascript
python通过装饰器检查函数参数数据类型的方法
2015/03/13 Python
Python运用于数据分析的简单教程
2015/03/27 Python
Apache如何部署django项目
2017/05/21 Python
Python函数参数操作详解
2018/08/03 Python
使用urllib库的urlretrieve()方法下载网络文件到本地的方法
2018/12/19 Python
Python实现二维曲线拟合的方法
2018/12/29 Python
Python利用字典破解WIFI密码的方法
2019/02/27 Python
Python中url标签使用知识点总结
2020/01/16 Python
python 基于opencv去除图片阴影
2021/01/26 Python
老生常谈CSS中的长度单位
2016/06/27 HTML / CSS
资生堂美国官网:Shiseido美国
2016/09/02 全球购物
Steve Madden官网:美国鞋类品牌
2017/01/29 全球购物
香港彩色隐形眼镜在线商店:Stunninglens(全球免费送货)
2019/05/10 全球购物
Linux常见面试题
2013/03/18 面试题
公司同意接收函
2014/01/13 职场文书
英语求职信范文
2014/05/23 职场文书
幼儿园端午节活动方案
2014/08/25 职场文书
2015年党员个人工作总结
2015/05/13 职场文书
Python绘制分类图的方法
2021/04/20 Python