python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中用sleep()方法操作时间的教程
May 22 Python
python实现按行切分文本文件的方法
Apr 18 Python
python爬虫获取京东手机图片的图文教程
Dec 29 Python
Python 保存矩阵为Excel的实现方法
Jan 28 Python
基于Python的PIL库学习详解
May 10 Python
python之pyqt5通过按钮改变Label的背景颜色方法
Jun 13 Python
在PyCharm的 Terminal(终端)切换Python版本的方法
Aug 02 Python
详解Django admin高级用法
Nov 06 Python
tensorflow 固定部分参数训练,只训练部分参数的实例
Jan 20 Python
python3.7调试的实例方法
Jul 21 Python
Python中的matplotlib绘制百分比堆叠柱状图,并为每一个类别设置不同的填充图案
Apr 20 Python
python画条形图的具体代码
Apr 20 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
php面向对象全攻略 (一) 面向对象基础知识
2009/09/30 PHP
深入PHP运行环境配置的详解
2013/06/04 PHP
php获取是星期几的的一些常用姿势
2019/12/15 PHP
如何实现浏览器上的右键菜单
2006/07/10 Javascript
在视频前插入广告
2006/11/20 Javascript
JQuery 初体验(建议学习jquery)
2009/04/25 Javascript
jQuery中$.get、$.post、$.getJSON和$.ajax的用法详解
2014/11/19 Javascript
jQuery结合CSS制作漂亮的select下拉菜单
2015/05/03 Javascript
jquery实现带渐变淡入淡出并向右依次展开的多级菜单效果实例
2015/08/22 Javascript
jquery mobile 移动web(5)
2015/12/20 Javascript
简单总结JavaScript中的String字符串类型
2016/05/26 Javascript
bootstrap导航条实现代码
2016/12/28 Javascript
聊聊Vue.js的template编译的问题
2017/10/09 Javascript
Bootstrap 中data-[*] 属性的整理
2018/03/13 Javascript
JavaScript实现写入文件到本地的方法【基于FileSaver.js插件】
2018/03/15 Javascript
对Angular中单向数据流的深入理解
2018/03/31 Javascript
jQuery操作cookie的示例代码
2019/06/05 jQuery
简单使用webpack打包文件的实现
2019/10/29 Javascript
vue实现图片上传预览功能
2019/12/23 Javascript
js实现整体缩放页面适配移动端
2020/03/31 Javascript
Python实现从url中提取域名的几种方法
2014/09/26 Python
python实现多线程的方式及多条命令并发执行
2016/06/07 Python
Python实现删除列表中满足一定条件的元素示例
2017/06/12 Python
python matplotlib画图实例代码分享
2017/12/27 Python
深入浅析Python传值与传址
2018/07/10 Python
Python 中PyQt5 点击主窗口弹出另一个窗口的实现方法
2019/07/04 Python
python使用opencv在Windows下调用摄像头实现解析
2019/11/26 Python
Django admin 实现search_fields精确查询实例
2020/03/30 Python
TensorFlow-gpu和opencv安装详细教程
2020/06/30 Python
python集合能干吗
2020/07/19 Python
报关简历自我评价怎么写
2013/09/19 职场文书
大四学生思想汇报
2014/01/13 职场文书
运动会解说词50字
2014/01/18 职场文书
《一个中国孩子的呼声》教学反思
2014/02/12 职场文书
安全生产承诺书范文
2014/05/22 职场文书
Java中的Kotlin 内部类原理
2022/06/16 Java/Android