TensorFlow 模型载入方法汇总(小结)


Posted in Python onJune 19, 2018

一、TensorFlow常规模型加载方法

保存模型

tf.train.Saver()类,.save(sess, ckpt文件目录)方法

参数名称 功能说明 默认值
var_list Saver中存储变量集合 全局变量集合
reshape 加载时是否恢复变量形状 True
sharded 是否将变量轮循放在所有设备上 True
max_to_keep 保留最近检查点个数 5
restore_sequentially 是否按顺序恢复变量,模型较大时顺序恢复内存消耗小 True

var_list是字典形式{变量名字符串: 变量符号},相对应的restore也根据同样形式的字典将ckpt中的字符串对应的变量加载给程序中的符号。

如果Saver给定了字典作为加载方式,则按照字典来,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否则每个变量寻找自己的name属性在ckpt中的对应值进行加载。

加载模型

当我们基于checkpoint文件(ckpt)加载参数时,实际上我们使用Saver.restore取代了initializer的初始化

TensorFlow 模型载入方法汇总(小结)

checkpoint文件会记录保存信息,通过它可以定位最新保存的模型:

ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)

TensorFlow 模型载入方法汇总(小结) 

.meta文件保存了当前图结构

.index文件保存了当前参数名

.data文件保存了当前参数值

tf.train.import_meta_graph函数给出model.ckpt-n.meta的路径后会加载图结构,并返回saver对象

ckpt = tf.train.get_checkpoint_state('./model/')

tf.train.Saver函数会返回加载默认图的saver对象,saver对象初始化时可以指定变量映射方式,根据名字映射变量(『TensorFlow』滑动平均)

saver = tf.train.Saver({"v/ExponentialMovingAverage":v})

saver.restore函数给出model.ckpt-n的路径后会自动寻找参数名-值文件进行加载

saver.restore(sess,'./model/model.ckpt-0')
saver.restore(sess,ckpt.model_checkpoint_path)

1.不加载图结构,只加载参数

由于实际上我们参数保存的都是Variable变量的值,所以其他的参数值(例如batch_size)等,我们在restore时可能希望修改,但是图结构在train时一般就已经确定了,所以我们可以使用tf.Graph().as_default()新建一个默认图(建议使用上下文环境),利用这个新图修改和变量无关的参值大小,从而达到目的。

'''
使用原网络保存的模型加载到自己重新定义的图上
可以使用python变量名加载模型,也可以使用节点名
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf
 
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
 
with tf.Graph().as_default() as g:
 
 x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
 y = Net.inference_1(x, N_CLASS=5, train=False)
 
 with tf.Session() as sess:
  # 程序前面得有 Variable 供 save or restore 才不报错
  # 否则会提示没有可保存的变量
  saver = tf.train.Saver()
 
  ckpt = tf.train.get_checkpoint_state('./model/')
  img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
  img = sess.run(tf.expand_dims(tf.image.resize_images(
   tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))
 
  if ckpt and ckpt.model_checkpoint_path:
   print(ckpt.model_checkpoint_path)
   saver.restore(sess,'./model/model.ckpt-0')
   global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
   res = sess.run(y, feed_dict={x: img})
   print(global_step,sess.run(tf.argmax(res,1)))

2.加载图结构和参数

'''
直接使用使用保存好的图
无需加载python定义的结构,直接使用节点名称加载模型
由于节点形状已经定下来了,所以有不便之处,placeholder定义batch后单张传会报错
现阶段不推荐使用,以后如果理解深入了可能会找到使用方法
'''
import AlexNet_train as train
import random
import tensorflow as tf
 
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
 
 
ckpt = tf.train.get_checkpoint_state('./model/')       # 通过检查点文件锁定最新的模型
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta') # 载入图结构,保存在.meta文件中
 
with tf.Session() as sess:
 saver.restore(sess,ckpt.model_checkpoint_path)      # 载入参数,参数保存在两个文件中,不过restore会自己寻找
 
 img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
 img = sess.run(tf.image.resize_images(
  tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))
 imgs = []
 for i in range(128):
  imgs.append(img)
 print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs}))
 
 '''
 img = sess.run(tf.expand_dims(tf.image.resize_images(
  tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))
 print(img)
 imgs = []
 for i in range(128):
  imgs.append(img)
 print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'),
     feed_dict={'Placeholder:0':img}))

注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。

3.简化版本

# 连同图结构一同加载
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
 saver.restore(sess,ckpt.model_checkpoint_path)
    
# 只加载数据,不加载图结构,可以在新图中改变batch_size等的值
# 不过需要注意,Saver对象实例化之前需要定义好新的图结构,否则会报错
saver = tf.train.Saver()
with tf.Session() as sess:
 ckpt = tf.train.get_checkpoint_state('./model/')
 saver.restore(sess,ckpt.model_checkpoint_path)

二、TensorFlow二进制模型加载方法

这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作

# 新建空白图
self.graph = tf.Graph()
# 空白图列为默认图
with self.graph.as_default():
 # 二进制读取模型文件
 with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
  # 新建GraphDef文件,用于临时载入模型中的图
  graph_def = tf.GraphDef()
  # GraphDef加载模型中的图
  graph_def.ParseFromString(f.read())
  # 在空白图中加载GraphDef中的图
  tf.import_graph_def(graph_def,name='')
  # 在图中获取张量需要使用graph.get_tensor_by_name加张量名
  # 这里的张量可以直接用于session的run方法求值了
  # 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
  self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
  self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的txt文件去重功能示例
Jul 07 Python
解决pandas.DataFrame.fillna 填充Nan失败的问题
Nov 06 Python
python网络编程之多线程同时接受和发送
Sep 03 Python
Django 自定义分页器的实现代码
Nov 24 Python
40行Python代码实现天气预报和每日鸡汤推送功能
Feb 27 Python
Python接口测试文件上传实例解析
May 22 Python
Keras-多输入多输出实例(多任务)
Jun 22 Python
python 从list中随机取值的方法
Nov 16 Python
使用PyCharm官方中文语言包汉化PyCharm
Nov 18 Python
python操作toml文件的示例代码
Nov 27 Python
anaconda安装pytorch1.7.1和torchvision0.8.2的方法(亲测可用)
Feb 01 Python
Python OpenCV 图像平移的实现示例
Jun 04 Python
python3爬虫之设计签名小程序
Jun 19 #Python
Python GUI Tkinter简单实现个性签名设计
Jun 19 #Python
TensorFlow数据输入的方法示例
Jun 19 #Python
深入分析python中整型不会溢出问题
Jun 18 #Python
Python登录注册验证功能实现
Jun 18 #Python
详解python3中zipfile模块用法
Jun 18 #Python
python爬取个性签名的方法
Jun 17 #Python
You might like
PHP简单系统数据添加以及数据删除模块源文件下载
2008/06/07 PHP
windows下PHP_intl.dll正确配置方法(apache2.2+php5.3.5)
2014/01/14 PHP
深入探究PHP的多进程编程方法
2015/08/18 PHP
利用PHPStorm如何开发Laravel应用详解
2017/08/30 PHP
实现laravel 插入操作日志到数据库的方法
2019/10/11 PHP
关于Aptana Studio生成自动备份文件的解决办法
2009/12/23 Javascript
自定义ExtJS控件之下拉树和下拉表格附源码
2013/10/15 Javascript
浅谈JavaScript的Polymer框架中的事件绑定
2015/07/29 Javascript
3种js实现string的substring方法
2015/11/09 Javascript
jquery判断复选框是否选中进行答题提示特效
2015/12/10 Javascript
移动端 一个简单易懂的弹出框
2016/07/06 Javascript
js如何判断是否在iframe中及防止网页被别站用iframe嵌套
2017/01/11 Javascript
Express与NodeJs创建服务器的两种方法
2017/02/06 NodeJs
在javascript中,null>=0 为真,null==0却为假,null的值详解
2017/02/22 Javascript
javascript 初学教程及五子棋小程序的简单实现
2017/07/04 Javascript
vue项目tween方法实现返回顶部的示例代码
2018/03/02 Javascript
Vue2.0 http请求以及loading展示实例
2018/03/06 Javascript
详解ajax的data参数错误导致页面崩溃
2018/04/30 Javascript
JQuery特殊效果和链式调用操作示例
2019/05/13 jQuery
通过js给网页加上水印背景实例
2019/06/17 Javascript
JavaScript This指向问题详解
2019/11/25 Javascript
解决vue项目打包上服务器显示404错误,本地没出错的问题
2020/11/03 Javascript
用python实现简单EXCEL数据统计的实例
2017/01/24 Python
python实现逻辑回归的方法示例
2017/05/02 Python
python 筛选数据集中列中value长度大于20的数据集方法
2018/06/14 Python
Python3.5文件修改操作实例分析
2019/05/01 Python
python定义类self用法实例解析
2020/01/22 Python
keras的siamese(孪生网络)实现案例
2020/06/12 Python
金宝贝童装官网:Gymboree
2016/08/31 全球购物
法人身份证明书
2014/10/08 职场文书
加强干部作风建设整改方案
2014/10/24 职场文书
2015年敬老月活动总结
2015/03/27 职场文书
考勤制度通知
2015/04/25 职场文书
小学语文教师竞聘演讲稿范文
2019/08/09 职场文书
Python基础之数据结构详解
2021/04/28 Python
JS实现数组去重的11种方法总结
2022/04/04 Javascript