tensorflow从ckpt和从.pb文件读取变量的值方式


Posted in Python onMay 26, 2020

最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形。

因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考。

(1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例)

from tensorflow.python import pywrap_tensorflow 
import tensorflow as tf
with tf.Graph().as_default(): 
 with tf.Session() as sess: 
 ckpt = tf.train.get_checkpoint_state('./model_ckpt') #保存ckpt文件的文件夹
 if ckpt and ckpt.model_checkpoint_path: 
 reader = pywrap_tensorflow.NewCheckpointReader('./model_ckpt/model.ckpt-999') #自己保存的ckpt文件名
 all_variables = reader.get_variable_to_shape_map() 
 w1 = reader.get_tensor("Variable_1") 
 print(w1.shape) 
 print(w1) 
 else: print('No checkpoint file found')

(2) 从保存的.pb文件读取变量的值(以读取保存的第一个权重为例)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
sess = tf.Session()
with gfile.FastGFile('Yourpb.pb', 'rb') as f: #自己保存的pb文件
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') 
 print(sess.run('Variable_1:0'))

补充知识:如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图

import tensorflow as tf
import os

CheckpointReader

tf.train.NewCheckpointReader是一个创建检查点读取器(CheckpointReader)对象的完美手段。 CheckpointReader中有几个非常有用的方法:

get_variable_to_shape_map() - 提供具有变量名称和形状的字典

debug_string() - 提供由检查点文件中所有变量组成的字符串

has_tensor(var_name) - 允许检查变量是否存在于检查点中

get_tensor(var_name) - 返回变量名称的张量

为了便于说明,我将定义一个函数来检查路径的有效性,并为您加载检查点读取器。

In [3]:

def load_reader(path):
 assert os.path.exists(path), "Provided incorrect path to the file. {} doesn't exist".format(path)
 return tf.train.NewCheckpointReader(path)

In [34]:

your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)

reader.debug_string()

用于返回包含以下内容的一个字符串:

variable name(变量名)

data type(数据类型)

tensor shape(张量类型)

它返回字符串的各元素间均用空格符' '分隔,你可以使用debug_string来创建一个变量名列表,如下所示:

In [53]:

all_var_descriptions = reader.debug_string().split()
var_names, var_shapes = all_var[::3], all_var[2::3]
print(var_names[:4])
print(var_shapes[:4])

输出:

['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']

但是,对于完成同样的任务,更好的方法是使用reader.get_variable_to_shape_map()

reader.get_variable_to_shape_map()

用于返回包含所有变量及其形状名称的字典,变量作为字典的Key,形状作为Value。

In [66]:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
reader.has_tensor(var_name)

返回bool值

这是一种方便的方法,允许您检查ckeckpoint中是否存在相关的变量。

In [51]:

names_that_exit = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for key in names_that_exit:
 print(key.decode()+':', names_that_exit[key])
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
reader.get_tensor(tensor_name)

返回包含检查点的张量值的NumPy数组

正常使用方法是先恢复一个张量,然后用恢复的张量初始化你自己的变量:

In [60]:

def recover_var(reader, var_name):
 recovered_var = 'var to be recovered'
 try:
  recovered_var = reader.get_tensor(var_name)
 except:
  assert reader.has_tensor(var_name),\
  "{} variable doesn't exist in the check point. Please check the variable name".format(var_name)
 return recovered_var

In [67]:

checkpoint_var = recover_var(reader, 'conv1/kernels')
print ("Recovered variable has the following shape: \n", checkpoint_var.shape)
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print ("New variable will be initialized with recovered values and the following shape: \n", new_var.get_shape())
Recovered variable has the following shape: 
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape: 
(3, 3, 3, 64)

以上这篇tensorflow从ckpt和从.pb文件读取变量的值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python语言编写电脑时间自动同步小工具
Mar 08 Python
详解Python中的__new__()方法的使用
Apr 09 Python
python实现对excel进行数据剔除操作实例
Dec 07 Python
使用apidocJs快速生成在线文档的实例讲解
Feb 07 Python
Python日期时间对象转换为字符串的实例
Jun 22 Python
对python插入数据库和生成插入sql的示例讲解
Nov 14 Python
python 将dicom图片转换成jpg图片的实例
Jan 13 Python
python对象销毁实例(垃圾回收)
Jan 16 Python
python3正则模块re的使用方法详解
Feb 11 Python
Python configparser模块配置文件过程解析
Mar 03 Python
深入了解NumPy 高级索引
Jul 24 Python
分享python函数常见关键字
Apr 26 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
You might like
使用openssl实现rsa非对称加密算法示例
2014/01/24 PHP
PHP通过加锁实现并发情况下抢码功能
2016/08/10 PHP
PHP PDOStatement::errorCode讲解
2019/01/31 PHP
JQuery切换显示的效果实例代码
2013/02/27 Javascript
jQuery Ajax异步处理Json数据详解
2013/11/05 Javascript
node.js中使用q.js实现api的promise化
2014/09/17 Javascript
JavaScript构造函数详解
2015/12/27 Javascript
js获取当前页的URL与window.location.href简单方法
2017/02/13 Javascript
angular动态删除ng-repaeat添加的dom节点的方法
2017/07/20 Javascript
Vue + Vue-router 同名路由切换数据不更新的方法
2017/11/20 Javascript
解决vue 项目引入字体图标报错、不显示等问题
2018/09/01 Javascript
基于javascript的拖拽类封装详解
2019/04/19 Javascript
微信小程序+云开发实现欢迎登录注册
2019/05/24 Javascript
javascript创建元素和删除元素实例小结
2019/06/19 Javascript
react实现同页面三级跳转路由布局
2019/09/26 Javascript
Python的迭代器和生成器
2015/07/29 Python
python3实现UDP协议的服务器和客户端
2017/06/14 Python
Python通过调用mysql存储过程实现更新数据功能示例
2018/04/03 Python
python求最大连续子数组的和
2018/07/07 Python
flask入门之表单的实现
2018/07/18 Python
利用pandas合并多个excel的方法示例
2019/10/10 Python
python实现简单井字棋小游戏
2020/03/05 Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
2020/06/12 Python
基于python图书馆管理系统设计实例详解
2020/08/05 Python
html5使用canvas实现跟随光标跳动的火焰效果
2014/01/07 HTML / CSS
HTML5基于flash实现播放RTMP协议视频的示例代码
2020/12/04 HTML / CSS
考博自荐信
2013/10/25 职场文书
工商技校毕业生自荐信
2013/11/15 职场文书
村干部培训方案
2014/05/02 职场文书
关于奉献的演讲稿
2014/05/21 职场文书
预备党员公开承诺书
2014/05/28 职场文书
网吧员工管理制度
2015/08/05 职场文书
廉洁自律准则学习心得体会
2016/01/13 职场文书
解决goland 导入项目后import里的包报红问题
2021/05/06 Golang
经典《舰娘》游改全新动画预告 预定11月开播
2022/04/01 日漫
V Rising 服务器搭建图文教程
2022/06/16 Servers