python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用paramiko模块实现ssh远程登陆上传文件并执行
Jan 27 Python
Django模板变量如何传递给外部js调用的方法小结
Jul 24 Python
django rest framework 数据的查找、过滤、排序的示例
Jun 25 Python
Python 字符串与数字输出方法
Jul 16 Python
pandas把所有大于0的数设置为1的方法
Jan 26 Python
Django中的用户身份验证示例详解
Aug 07 Python
python计算二维矩形IOU实例
Jan 18 Python
python图形开发GUI库wxpython使用方法详解
Feb 14 Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 Python
Python实现EM算法实例代码
Oct 04 Python
python Tornado框架的使用示例
Oct 19 Python
Python+OpenCV实现在图像上绘制矩形
Mar 21 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
php模块memcache和memcached区别分析
2011/06/14 PHP
php中return的用法实例分析
2015/02/28 PHP
详解PHP+AJAX无刷新分页实现方法
2015/11/03 PHP
PHP单例模式与工厂模式详解
2017/08/29 PHP
PHP实现深度优先搜索算法(DFS,Depth First Search)详解
2017/09/16 PHP
PHP实现微信公众号验证Token的示例代码
2019/12/16 PHP
GRID拖拽行的实例代码
2013/07/18 Javascript
js中function()使用方法
2013/12/24 Javascript
Jquery倒计时源码分享
2014/05/16 Javascript
nodejs 实现模拟form表单上传文件
2014/07/14 NodeJs
javascript中sort()的用法实例分析
2015/01/30 Javascript
js实现页面跳转的几种方法小结
2016/05/16 Javascript
模拟javascript中的sort排序(简单实例)
2016/08/17 Javascript
vue路由前进后退动画效果的实现代码
2018/12/10 Javascript
Vue自定义属性实例分析
2019/02/23 Javascript
微信小程序云开发之数据库操作
2019/05/18 Javascript
举例讲解Django中数据模型访问外键值的方法
2015/07/21 Python
Python使用multiprocessing实现一个最简单的分布式作业调度系统
2016/03/14 Python
Python数据类型详解(一)字符串
2016/05/08 Python
Python获取本机所有网卡ip,掩码和广播地址实例代码
2018/01/22 Python
python利用7z批量解压rar的实现
2019/08/07 Python
Python基于Twilio及腾讯云实现国际国内短信接口
2020/06/18 Python
HTML5到底会有什么发展?HTML5的前景展望
2015/07/07 HTML / CSS
世界上第一个水枕头:Mediflow
2018/12/06 全球购物
一道SQL面试题
2012/12/31 面试题
会计自我鉴定范文
2013/10/06 职场文书
大四学生思想汇报
2014/01/13 职场文书
幼儿园教师自我鉴定
2014/03/20 职场文书
广告宣传策划方案
2014/05/21 职场文书
群众路线教育实践活动思想汇报(2014特荐篇)
2014/09/16 职场文书
盗窃罪辩护词范文
2015/05/21 职场文书
合作意向书范本
2019/04/17 职场文书
Python中Selenium对Cookie的操作方法
2021/07/09 Python
vue 数字翻牌器动态加载数据
2022/04/20 Vue.js
解决Git推送错误non-fast-forward的方法
2022/06/25 Servers
win10系统xps文件怎么打开?win10打开xps文件的两种操作方法
2022/07/23 数码科技