python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django中URL视图函数的一些高级概念介绍
Jul 20 Python
Python实现Linux命令xxd -i功能
Mar 06 Python
在 Python 应用中使用 MongoDB的方法
Jan 05 Python
Python解惑之整数比较详解
Apr 24 Python
python实现在图片上画特定大小角度矩形框
Oct 24 Python
python挖矿算力测试程序详解
Jul 03 Python
Python定时任务APScheduler的实例实例详解
Jul 22 Python
pandas中DataFrame修改index、columns名的方法示例
Aug 02 Python
python字符串格式化方式解析
Oct 19 Python
基于Python计算圆周率pi代码实例
Mar 25 Python
Python爬虫+Tkinter制作一个翻译软件的示例
Feb 20 Python
给numpy.array增加维度的超简单方法
Jun 02 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
哪吒敖丙传:新人物二哥敖乙出场 小敖丙奶气十足
2020/03/08 国漫
PHP HTML代码串 截取实现代码
2009/06/29 PHP
献给php初学者(入门学习经验谈)
2010/10/12 PHP
解析Linux下Varnish缓存的配置优化
2013/06/20 PHP
非常实用的PHP常用函数汇总
2014/12/17 PHP
php返回当前日期或者指定日期是周几
2015/05/21 PHP
使用PHP编写发红包程序
2015/07/22 PHP
PHP中使用array函数新建一个数组
2015/11/19 PHP
PHP+Jquery与ajax相结合实现下拉淡出瀑布流效果【无需插件】
2016/05/06 PHP
thinkPHP批量删除的实现方法分析
2016/11/09 PHP
整理一些JavaScript的IE和火狐的兼容性注意事项
2011/03/17 Javascript
如何制作浮动广告 JavaScript制作浮动广告代码
2012/12/30 Javascript
javascript自定义startWith()和endWith()的两种方法
2013/11/11 Javascript
提交按钮的name='submit'引起的js失效问题及原因
2015/02/25 Javascript
JavaScript中关于for循环删除数组元素内容时出现的问题
2016/11/21 Javascript
利用浮层使select不可选的实现方法
2016/12/03 Javascript
javascript数组去重方法分析
2016/12/15 Javascript
HTML5+Canvas调用手机拍照功能实现图片上传(下)
2017/04/21 Javascript
微信小程序之GET请求的实例详解
2017/09/29 Javascript
微信小程序自定义单项选择器样式
2019/07/25 Javascript
微信sdk实现禁止微信分享(使用原生php实现)
2019/11/15 Javascript
[44:39]2014 DOTA2国际邀请赛中国区预选赛 NE VS CNB
2014/05/21 DOTA
[46:23]OG vs EG 2018国际邀请赛淘汰赛BO3 第一场 8.23
2018/08/24 DOTA
Python同时向控制台和文件输出日志logging的方法
2015/05/26 Python
Python中time模块和datetime模块的用法示例
2016/02/28 Python
分享python数据统计的一些小技巧
2016/07/21 Python
python读取文本中的坐标方法
2018/10/14 Python
Django实现微信小程序的登录验证功能并维护登录态
2019/07/04 Python
给 TensorFlow 变量进行赋值的方式
2020/02/10 Python
美国羊皮公司:Overland
2018/01/15 全球购物
*p++ 自增p 还是p所指向的变量
2016/07/16 面试题
会计专业自荐信
2014/06/03 职场文书
优秀毕业生的求职信
2014/07/21 职场文书
2015年银行工作总结范文
2015/04/01 职场文书
2015年度护士个人工作总结
2015/04/09 职场文书
详解MySQL主从复制及读写分离
2021/05/07 MySQL