tensorflow从ckpt和从.pb文件读取变量的值方式


Posted in Python onMay 26, 2020

最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形。

因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考。

(1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例)

from tensorflow.python import pywrap_tensorflow 
import tensorflow as tf
with tf.Graph().as_default(): 
 with tf.Session() as sess: 
 ckpt = tf.train.get_checkpoint_state('./model_ckpt') #保存ckpt文件的文件夹
 if ckpt and ckpt.model_checkpoint_path: 
 reader = pywrap_tensorflow.NewCheckpointReader('./model_ckpt/model.ckpt-999') #自己保存的ckpt文件名
 all_variables = reader.get_variable_to_shape_map() 
 w1 = reader.get_tensor("Variable_1") 
 print(w1.shape) 
 print(w1) 
 else: print('No checkpoint file found')

(2) 从保存的.pb文件读取变量的值(以读取保存的第一个权重为例)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
sess = tf.Session()
with gfile.FastGFile('Yourpb.pb', 'rb') as f: #自己保存的pb文件
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') 
 print(sess.run('Variable_1:0'))

补充知识:如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图

import tensorflow as tf
import os

CheckpointReader

tf.train.NewCheckpointReader是一个创建检查点读取器(CheckpointReader)对象的完美手段。 CheckpointReader中有几个非常有用的方法:

get_variable_to_shape_map() - 提供具有变量名称和形状的字典

debug_string() - 提供由检查点文件中所有变量组成的字符串

has_tensor(var_name) - 允许检查变量是否存在于检查点中

get_tensor(var_name) - 返回变量名称的张量

为了便于说明,我将定义一个函数来检查路径的有效性,并为您加载检查点读取器。

In [3]:

def load_reader(path):
 assert os.path.exists(path), "Provided incorrect path to the file. {} doesn't exist".format(path)
 return tf.train.NewCheckpointReader(path)

In [34]:

your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)

reader.debug_string()

用于返回包含以下内容的一个字符串:

variable name(变量名)

data type(数据类型)

tensor shape(张量类型)

它返回字符串的各元素间均用空格符' '分隔,你可以使用debug_string来创建一个变量名列表,如下所示:

In [53]:

all_var_descriptions = reader.debug_string().split()
var_names, var_shapes = all_var[::3], all_var[2::3]
print(var_names[:4])
print(var_shapes[:4])

输出:

['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']

但是,对于完成同样的任务,更好的方法是使用reader.get_variable_to_shape_map()

reader.get_variable_to_shape_map()

用于返回包含所有变量及其形状名称的字典,变量作为字典的Key,形状作为Value。

In [66]:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
reader.has_tensor(var_name)

返回bool值

这是一种方便的方法,允许您检查ckeckpoint中是否存在相关的变量。

In [51]:

names_that_exit = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for key in names_that_exit:
 print(key.decode()+':', names_that_exit[key])
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
reader.get_tensor(tensor_name)

返回包含检查点的张量值的NumPy数组

正常使用方法是先恢复一个张量,然后用恢复的张量初始化你自己的变量:

In [60]:

def recover_var(reader, var_name):
 recovered_var = 'var to be recovered'
 try:
  recovered_var = reader.get_tensor(var_name)
 except:
  assert reader.has_tensor(var_name),\
  "{} variable doesn't exist in the check point. Please check the variable name".format(var_name)
 return recovered_var

In [67]:

checkpoint_var = recover_var(reader, 'conv1/kernels')
print ("Recovered variable has the following shape: \n", checkpoint_var.shape)
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print ("New variable will be initialized with recovered values and the following shape: \n", new_var.get_shape())
Recovered variable has the following shape: 
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape: 
(3, 3, 3, 64)

以上这篇tensorflow从ckpt和从.pb文件读取变量的值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python分析apache访问日志脚本分享
Feb 26 Python
python提取内容关键词的方法
Mar 16 Python
Python将list中的string批量转化成int/float的方法
Jun 26 Python
详解python单元测试框架unittest
Jul 02 Python
Python爬虫之网页图片抓取的方法
Jul 16 Python
Django项目中添加ldap登陆认证功能的实现
Apr 04 Python
使用python画社交网络图实例代码
Jul 10 Python
Django 开发环境配置过程详解
Jul 18 Python
python BlockingScheduler定时任务及其他方式的实现
Sep 19 Python
区分python中的进程与线程
Aug 13 Python
使用bandit对目标python代码进行安全函数扫描的案例分析
Jan 27 Python
Python爬虫实战之爬取京东商品数据并实实现数据可视化
Jun 07 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
You might like
PHP下一个非常全面获取图象信息的函数
2008/11/20 PHP
使用JSON实现数据的跨域传输的php代码
2011/12/20 PHP
ThinkPHP学习笔记(一)ThinkPHP部署
2014/06/22 PHP
ThinkPHP实现将SESSION存入MYSQL的方法
2014/07/22 PHP
Codeigniter的一些优秀特性总结
2015/01/21 PHP
php创建多级目录的方法
2015/03/24 PHP
thinkphp3.2.3版本的数据库增删改查实现代码
2016/09/22 PHP
PHP创建文件及写入数据(覆盖写入,追加写入)的方法详解
2019/02/15 PHP
判断是否输入完毕再激活提交按钮
2006/06/26 Javascript
用JS剩余字数计算的代码
2008/07/03 Javascript
JS验证邮箱格式是否正确的代码
2013/12/05 Javascript
微信小程序技巧之show内容展示,上传文件编码问题
2017/01/23 Javascript
Django1.7+JQuery+Ajax验证用户注册集成小例子
2017/04/08 jQuery
jQueryMobile之窗体长内容的缺陷与解决方法实例分析
2017/09/20 jQuery
使用 vue-i18n 切换中英文效果
2018/05/23 Javascript
使用vue.js在页面内组件监听scroll事件的方法
2018/09/11 Javascript
vue封装一个简单的div框选时间的组件的方法
2019/01/06 Javascript
node.js express框架简介与实现
2019/07/23 Javascript
详解Vue2.5+迁移至Typescript指南
2019/08/01 Javascript
vue递归组件实战之简单树形控件实例代码
2019/08/27 Javascript
jQuery实现弹出层效果
2019/12/10 jQuery
node.js中fs文件系统模块的使用方法实例详解
2020/02/13 Javascript
使用Python向C语言的链接库传递数组、结构体、指针类型的数据
2019/01/29 Python
python是否适合网页编程详解
2019/10/04 Python
python带参数打包exe及调用方式
2019/12/21 Python
Django使用Celery加redis执行异步任务的实例内容
2020/02/20 Python
Jmeter HTTPS接口测试证书导入过程图解
2020/07/22 Python
requests在python中发送请求的实例讲解
2021/02/17 Python
python快速安装OpenCV的步骤记录
2021/02/22 Python
嘻哈珠宝品牌:KRKC&CO
2020/10/19 全球购物
宿舍卫生检讨书
2014/01/16 职场文书
广播节目策划方案
2014/05/23 职场文书
学校综治宣传月活动总结
2014/07/02 职场文书
企业财务总监岗位职责
2015/04/03 职场文书
员工规章制度范本
2015/08/07 职场文书
Python装饰器详细介绍
2022/03/25 Python