python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python制作钉钉加密/解密工具
Dec 07 Python
Django查询数据库的性能优化示例代码
Sep 24 Python
python字典操作实例详解
Nov 16 Python
Python实现简单石头剪刀布游戏
Jan 20 Python
在python中对变量判断是否为None的三种方法总结
Jan 23 Python
python内存动态分配过程详解
Jul 15 Python
深入学习python多线程与GIL
Aug 26 Python
Python lambda表达式filter、map、reduce函数用法解析
Sep 11 Python
python如何实现单链表的反转
Feb 10 Python
pyspark给dataframe增加新的一列的实现示例
Apr 24 Python
python 实现汉诺塔游戏
Nov 28 Python
python FTP编程基础入门
Feb 27 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
基于qmail的完整WEBMAIL解决方案安装详解
2006/10/09 PHP
用PHP的ob_start() 控制您的浏览器cache
2009/08/03 PHP
浏览器预览PHP文件时顶部出现空白影响布局分析原因及解决办法
2013/01/11 PHP
php中session使用示例
2014/03/29 PHP
windows7下php开发环境搭建图文教程
2015/01/06 PHP
PHP简单生成缩略图相册的方法
2015/07/29 PHP
php安装ssh2扩展的方法【Linux平台】
2016/07/20 PHP
PHP删除二维数组中相同元素及数组重复值的方法示例
2017/05/05 PHP
Yii2框架实现利用mpdf创建pdf文件功能示例
2019/02/08 PHP
javascript 模式设计之工厂模式学习心得
2010/04/27 Javascript
JavaScript脚本判断蜘蛛来源的方法
2015/09/22 Javascript
JavaScript操作URL的相关内容集锦
2015/10/29 Javascript
js中flexible.js实现淘宝弹性布局方案
2020/06/23 Javascript
详解AngularJS中自定义过滤器
2015/12/28 Javascript
浅谈JS之iframe中的窗口
2016/09/13 Javascript
原生JS改变透明度实现轮播效果
2017/03/24 Javascript
angularjs使用gulp-uglify压缩后执行报错的解决方法
2018/03/07 Javascript
关于js陀螺仪的理解分析
2019/04/11 Javascript
vue中全局路由守卫中替代this操作(this.$store/this.$vux)
2020/07/24 Javascript
解决VUE项目使用Element-ui 下拉组件的验证失效问题
2020/11/07 Javascript
Python实现信用卡系统(支持购物、转账、存取钱)
2016/06/24 Python
Python字典及字典基本操作方法详解
2018/01/30 Python
django的登录注册系统的示例代码
2018/05/14 Python
Python 实现异步调用函数的示例讲解
2018/10/14 Python
用Python实现读写锁的示例代码
2018/11/05 Python
python多进程下实现日志记录按时间分割
2019/07/22 Python
python 计算积分图和haar特征的实例代码
2019/11/20 Python
pytorch 自定义参数不更新方式
2020/01/06 Python
微信小程序canvas实现水平、垂直居中效果
2020/02/05 HTML / CSS
Daniel Wellington官方海外旗舰店:丹尼尔惠灵顿DW手表
2018/02/22 全球购物
优秀班组申报材料
2014/12/25 职场文书
教育实习指导教师评语
2014/12/31 职场文书
行政主管岗位职责
2015/02/03 职场文书
Java中PriorityQueue实现最小堆和最大堆的用法
2021/06/27 Java/Android
CSS实现隐藏搜索框功能(动画正反向序列)
2021/07/21 HTML / CSS
MySQL数据库查询之多表查询总结
2022/08/05 MySQL