python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Mac下Supervisor进程监控管理工具的安装与配置
Dec 16 Python
Python3连接MySQL(pymysql)模拟转账实现代码
May 24 Python
深入理解python中的atexit模块
Mar 07 Python
Sanic框架基于类的视图用法示例
Jul 18 Python
Python线性拟合实现函数与用法示例
Dec 13 Python
python实现五子棋游戏
Jun 18 Python
python selenium 查找隐藏元素 自动播放视频功能
Jul 24 Python
Python中利用LSTM模型进行时间序列预测分析的实现
Jul 26 Python
Python爬取知乎图片代码实现解析
Sep 17 Python
如何利用Python识别图片中的文字
May 31 Python
Python面向对象实现方法总结
Aug 12 Python
详解Python中的GIL(全局解释器锁)详解及解决GIL的几种方案
Jan 29 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
全国FM电台频率大全 - 23 四川省
2020/03/11 无线电
php && 逻辑与运算符使用说明
2010/03/04 PHP
完美解决thinkphp验证码出错无法显示的方法
2014/12/09 PHP
PHP精确到毫秒秒杀倒计时实例详解
2019/03/14 PHP
解决Jquery鼠标经过不停滑动的问题
2014/03/03 Javascript
js实现简单秒表走动的时钟特效
2020/03/25 Javascript
javascript跨域总结之window.name实现的跨域数据传输
2015/11/01 Javascript
jquery UI Datepicker时间控件的使用方法(基础版)
2015/11/07 Javascript
JQuery导航菜单选择特效
2016/04/11 Javascript
基于BootStrap Metronic开发框架经验小结【四】Bootstrap图标的提取和利用
2016/05/12 Javascript
Angular中实现树形结构视图实例代码
2017/05/05 Javascript
Angular2使用Angular-CLI快速搭建工程(二)
2017/05/21 Javascript
详解angularjs利用ui-route异步加载组件
2017/05/21 Javascript
jQuery实现简单的下拉菜单导航功能示例
2017/12/07 jQuery
JavaScript树的深度优先遍历和广度优先遍历算法示例
2018/07/30 Javascript
Element input树型下拉框的实现代码
2018/12/21 Javascript
JS立即执行的匿名函数用法分析
2019/11/04 Javascript
vue 解决兄弟组件、跨组件深层次的通信操作
2020/07/27 Javascript
vue使用video插件vue-video-player详解
2020/10/23 Javascript
[00:32]2018DOTA2亚洲邀请赛出场——VP
2018/04/04 DOTA
Python Web框架Flask信号机制(signals)介绍
2015/01/01 Python
Python基于回溯法子集树模板解决全排列问题示例
2017/09/07 Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
2017/12/14 Python
python实现百度语音识别api
2018/04/10 Python
python版本的仿windows计划任务工具
2018/04/30 Python
解决Pycharm后台indexing导致不能run的问题
2019/06/27 Python
python3实现带多张图片、附件的邮件发送
2019/08/10 Python
python中有帮助函数吗
2020/06/19 Python
简单了解Django项目应用创建过程
2020/07/06 Python
Python如何将装饰器定义为类
2020/07/30 Python
施工单位安全责任书
2014/07/24 职场文书
校园环保广播稿(3篇)
2014/09/15 职场文书
小学生九一八纪念日83周年演讲稿500字
2014/09/17 职场文书
先进教师事迹材料
2014/12/16 职场文书
工程款申请报告
2015/05/15 职场文书
2015年高校保卫处工作总结
2015/07/23 职场文书