python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python用字典统计单词或汉字词个数示例
Apr 22 Python
详解在Python的Django框架中创建模板库的方法
Jul 20 Python
使用 Python 实现微信公众号粉丝迁移流程
Jan 03 Python
Python解决八皇后问题示例
Apr 22 Python
Django框架使用mysql视图操作示例
May 15 Python
Python生成指定数量的优惠码实操内容
Jun 18 Python
获取django框架orm query执行的sql语句实现方法分析
Jun 20 Python
python中必要的名词解释
Nov 20 Python
Python使用uuid库生成唯一标识ID
Feb 12 Python
浅谈多卡服务器下隐藏部分 GPU 和 TensorFlow 的显存使用设置
Jun 30 Python
Python3爬虫发送请求的知识点实例
Jul 30 Python
Python如何将将模块分割成多个文件
Aug 04 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
php 仿Comsenz安装效果代码打包提供下载
2010/05/09 PHP
本地计算机无法启动Apache故障处理
2014/08/08 PHP
Laravel 5框架学习之向视图传送数据
2015/04/08 PHP
功能强大的PHP发邮件类
2016/08/29 PHP
php正则提取html图片(img)src地址与任意属性的方法
2017/02/08 PHP
PHP简单实现解析xml为数组的方法
2018/05/02 PHP
PHP实现用session来实现记录用户登陆信息
2018/10/15 PHP
js以对象为索引的关联数组
2010/07/04 Javascript
查看大图功能代码jquery版
2013/11/05 Javascript
jquery插件jquery倒计时插件分享
2013/12/27 Javascript
document.execCommand()的用法小结
2014/01/08 Javascript
jQuery 鼠标经过(hover)事件的延时处理示例
2014/04/14 Javascript
构造函数+原型模式构造js自定义对象(最通用)
2014/05/12 Javascript
JS 在指定数组中随机取出N个不重复的数据
2014/06/10 Javascript
JQuery实现的购物车功能(可以减少或者添加商品并自动计算价格)
2015/01/13 Javascript
JS中prototype的用法实例分析
2015/03/19 Javascript
JavaScript轮播图简单制作方法
2017/02/20 Javascript
详解vue父子模版嵌套案例
2017/03/04 Javascript
详解webpack4多入口、多页面项目构建案例
2018/05/25 Javascript
Angular4.x Event (DOM事件和自定义事件详解)
2018/10/09 Javascript
微信小程序实现下拉菜单切换效果
2020/03/30 Javascript
js观察者模式的弹幕案例
2020/11/23 Javascript
python利用datetime模块计算时间差
2015/08/04 Python
python 实时遍历日志文件
2016/04/12 Python
Python实现Youku视频批量下载功能
2017/03/14 Python
Python随机读取文件实现实例
2017/05/25 Python
Python实现的朴素贝叶斯分类器示例
2018/01/06 Python
Python函数的定义方式与函数参数问题实例分析
2019/12/26 Python
Python装饰器实现方法及应用场景详解
2020/03/26 Python
大学生学习生活的自我评价
2013/11/01 职场文书
大二学生学习个人自我评价
2014/01/19 职场文书
家具商场的活动方案
2014/08/16 职场文书
捐书倡议书
2014/08/29 职场文书
2014年大学生党员评议表自我评价
2014/09/20 职场文书
网站文案策划岗位职责
2015/04/14 职场文书
圣诞晚会主持词
2015/07/01 职场文书