tensorflow从ckpt和从.pb文件读取变量的值方式


Posted in Python onMay 26, 2020

最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形。

因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考。

(1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例)

from tensorflow.python import pywrap_tensorflow 
import tensorflow as tf
with tf.Graph().as_default(): 
 with tf.Session() as sess: 
 ckpt = tf.train.get_checkpoint_state('./model_ckpt') #保存ckpt文件的文件夹
 if ckpt and ckpt.model_checkpoint_path: 
 reader = pywrap_tensorflow.NewCheckpointReader('./model_ckpt/model.ckpt-999') #自己保存的ckpt文件名
 all_variables = reader.get_variable_to_shape_map() 
 w1 = reader.get_tensor("Variable_1") 
 print(w1.shape) 
 print(w1) 
 else: print('No checkpoint file found')

(2) 从保存的.pb文件读取变量的值(以读取保存的第一个权重为例)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
sess = tf.Session()
with gfile.FastGFile('Yourpb.pb', 'rb') as f: #自己保存的pb文件
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') 
 print(sess.run('Variable_1:0'))

补充知识:如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图

import tensorflow as tf
import os

CheckpointReader

tf.train.NewCheckpointReader是一个创建检查点读取器(CheckpointReader)对象的完美手段。 CheckpointReader中有几个非常有用的方法:

get_variable_to_shape_map() - 提供具有变量名称和形状的字典

debug_string() - 提供由检查点文件中所有变量组成的字符串

has_tensor(var_name) - 允许检查变量是否存在于检查点中

get_tensor(var_name) - 返回变量名称的张量

为了便于说明,我将定义一个函数来检查路径的有效性,并为您加载检查点读取器。

In [3]:

def load_reader(path):
 assert os.path.exists(path), "Provided incorrect path to the file. {} doesn't exist".format(path)
 return tf.train.NewCheckpointReader(path)

In [34]:

your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)

reader.debug_string()

用于返回包含以下内容的一个字符串:

variable name(变量名)

data type(数据类型)

tensor shape(张量类型)

它返回字符串的各元素间均用空格符' '分隔,你可以使用debug_string来创建一个变量名列表,如下所示:

In [53]:

all_var_descriptions = reader.debug_string().split()
var_names, var_shapes = all_var[::3], all_var[2::3]
print(var_names[:4])
print(var_shapes[:4])

输出:

['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']

但是,对于完成同样的任务,更好的方法是使用reader.get_variable_to_shape_map()

reader.get_variable_to_shape_map()

用于返回包含所有变量及其形状名称的字典,变量作为字典的Key,形状作为Value。

In [66]:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
reader.has_tensor(var_name)

返回bool值

这是一种方便的方法,允许您检查ckeckpoint中是否存在相关的变量。

In [51]:

names_that_exit = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for key in names_that_exit:
 print(key.decode()+':', names_that_exit[key])
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
reader.get_tensor(tensor_name)

返回包含检查点的张量值的NumPy数组

正常使用方法是先恢复一个张量,然后用恢复的张量初始化你自己的变量:

In [60]:

def recover_var(reader, var_name):
 recovered_var = 'var to be recovered'
 try:
  recovered_var = reader.get_tensor(var_name)
 except:
  assert reader.has_tensor(var_name),\
  "{} variable doesn't exist in the check point. Please check the variable name".format(var_name)
 return recovered_var

In [67]:

checkpoint_var = recover_var(reader, 'conv1/kernels')
print ("Recovered variable has the following shape: \n", checkpoint_var.shape)
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print ("New variable will be initialized with recovered values and the following shape: \n", new_var.get_shape())
Recovered variable has the following shape: 
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape: 
(3, 3, 3, 64)

以上这篇tensorflow从ckpt和从.pb文件读取变量的值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现截屏的函数
Jul 25 Python
Python获取文件所在目录和文件名的方法
Jan 12 Python
Python使用PyCrypto实现AES加密功能示例
May 22 Python
Python 3实战爬虫之爬取京东图书的图片详解
Oct 09 Python
python如何统计代码运行的时长
Jul 24 Python
决策树剪枝算法的python实现方法详解
Sep 18 Python
8段用于数据清洗Python代码(小结)
Oct 31 Python
Python3连接Mysql8.0遇到的问题及处理步骤
Feb 17 Python
python中的django是做什么的
Jul 31 Python
解决阿里云邮件发送不能使用25端口问题
Aug 07 Python
linux系统下pip升级报错的解决方法
Jan 31 Python
python爬虫之利用selenium模块自动登录CSDN
Apr 22 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
You might like
PHP中=赋值操作符对不同数据类型的不同行为
2011/01/02 PHP
php控制linux服务器常用功能 关机 重启 开新站点等
2012/09/05 PHP
php+highchats生成动态统计图
2014/05/21 PHP
php简单获取目录列表的方法
2015/03/24 PHP
用jQuery实现检测浏览器及版本的脚本代码
2008/01/22 Javascript
javascript笔试题目附答案@20081025_jb51.net
2008/10/26 Javascript
jQuery入门第一课 jQuery选择符
2010/03/14 Javascript
juqery 学习之四 筛选过滤
2010/11/30 Javascript
最新28个很棒的jQuery 教程
2011/05/28 Javascript
JS读取cookies信息(记录用户名)
2012/01/10 Javascript
JSONP 跨域访问代理API-yahooapis实现代码
2012/12/02 Javascript
jquery高效反选具体实现
2013/05/05 Javascript
js replace 与replaceall实例用法详解
2013/08/03 Javascript
javascript实现十秒钟后注册按钮可点击的方法
2015/05/13 Javascript
JS 实现随机验证码功能
2017/02/15 Javascript
详解webpack和webpack-simple中如何引入css文件
2017/06/28 Javascript
vue.js实现点击后动态添加class及删除同级class的实现代码
2018/04/04 Javascript
vue2.0 element-ui中el-select选择器无法显示选中的内容(解决方法)
2018/08/24 Javascript
解决vue-router在同一个路由下切换,取不到变化的路由参数问题
2018/09/01 Javascript
vue中的过滤器实例代码详解
2019/06/06 Javascript
js实现坦克移动小游戏
2019/10/28 Javascript
[02:48]DOTA2英雄基础教程 拉席克
2013/12/12 DOTA
[43:53]OG vs EG 2019国际邀请赛淘汰赛 胜者组 BO3 第三场 8.22
2019/09/05 DOTA
Python使用三种方法实现PCA算法
2017/12/12 Python
Python自然语言处理 NLTK 库用法入门教程【经典】
2018/06/26 Python
react+django清除浏览器缓存的几种方法小结
2019/07/17 Python
Python进度条的制作代码实例
2019/08/31 Python
Keras 在fit_generator训练方式中加入图像random_crop操作
2020/07/03 Python
celery在python爬虫中定时操作实例讲解
2020/11/27 Python
HTML5验证以及日期显示的实现详解
2013/07/05 HTML / CSS
德国宠物用品、宠物食品及水族馆网上商店:ZooRoyal
2017/07/09 全球购物
英国复古和经典球衣网站:Vintage Football Shirts
2018/10/05 全球购物
杭州龙健科技笔试题.net部分笔试题
2016/01/24 面试题
研究生个人学年总结
2015/02/14 职场文书
学校中秋节活动总结
2015/03/23 职场文书
导游词之藏龙百瀑景区
2019/12/30 职场文书