python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python简单程序读取串口信息的方法
Mar 13 Python
python中threading超线程用法实例分析
May 16 Python
浅谈Python基础之I/O模型
May 11 Python
Python3实战之爬虫抓取网易云音乐的热门评论
Oct 09 Python
Python中turtle作图示例
Nov 15 Python
TensorFlow saver指定变量的存取
Mar 10 Python
python MySQLdb使用教程详解
Mar 20 Python
python多线程与多进程及其区别详解
Aug 08 Python
python flask搭建web应用教程
Nov 19 Python
python数据库编程 ODBC方式实现通讯录
Mar 27 Python
在python中求分布函数相关的包实例
Apr 15 Python
利用Matlab绘制各类特殊图形的实例代码
Jul 16 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
将时间以距今多久的形式表示,PHP,js双版本
2012/09/25 PHP
探讨如何使用SimpleXML函数来加载和解析XML文档
2013/06/07 PHP
php中0,null,empty,空,false,字符串关系的详细介绍
2013/06/20 PHP
php堆排序(heapsort)练习
2013/11/13 PHP
PHP实现使用优酷土豆视频地址获取swf播放器分享地址
2014/06/05 PHP
thinkphp模板赋值与替换实例简述
2014/11/24 PHP
PHP实现在线阅读PDF文件的方法
2015/06/17 PHP
WordPres对前端页面调试时的两个PHP函数使用小技巧
2015/12/22 PHP
一个判断email合法性的函数[非正则]
2008/12/09 Javascript
Jquery 自定义动画概述及示例
2013/03/29 Javascript
解决angular的$http.post()提交数据时后台接收不到参数值问题的方法
2015/12/10 Javascript
基于JavaScript实现带缩略图的轮播效果
2017/01/12 Javascript
jquery请求servlet实现ajax异步请求的示例
2017/06/03 jQuery
JS中LocalStorage与SessionStorage五种循序渐进的使用方法
2017/07/12 Javascript
Angularjs cookie 操作实例详解
2017/09/27 Javascript
详解Vue路由钩子及应用场景(小结)
2017/11/07 Javascript
微信小程序ajax实现请求服务器数据及模版遍历数据功能示例
2017/12/15 Javascript
JS解惑之Object中的key是有序的么
2019/05/06 Javascript
浅析Angular 实现一个repeat指令的方法
2019/07/21 Javascript
[02:19]2014DOTA2国际邀请赛 专访820少年们一起去追梦吧
2014/07/14 DOTA
Python中3种内建数据结构:列表、元组和字典
2014/11/30 Python
python从入门到精通(DAY 2)
2015/12/20 Python
python中如何正确使用正则表达式的详细模式(Verbose mode expression)
2017/11/08 Python
Python二叉树的定义及常用遍历算法分析
2017/11/24 Python
Python解决pip install时出现的Could not fetch URL问题
2019/08/01 Python
pandas数据处理之绘图的实现
2020/06/15 Python
详解anaconda离线安装pytorchGPU版
2020/09/08 Python
苹果音乐订阅:Apple Music
2018/08/02 全球购物
空字符串(“”)和null的区别
2012/11/13 面试题
银行会计财务工作个人的自我评价
2013/10/29 职场文书
退伍老兵事迹材料
2014/01/31 职场文书
2015年度电厂个人工作总结
2015/05/13 职场文书
2015年税务稽查工作总结
2015/05/26 职场文书
2016年寒假家长评语
2015/10/10 职场文书
浅谈Redis主从复制以及主从复制原理
2021/05/29 Redis
Python的这些库,你知道多少?
2021/06/09 Python