python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python封装对象实现时间效果
Apr 23 Python
在Python3中初学者应会的一些基本的提升效率的小技巧
Mar 31 Python
python实现分析apache和nginx日志文件并输出访客ip列表的方法
Apr 04 Python
详解Python网络爬虫功能的基本写法
Jan 28 Python
Python执行时间的计算方法小结
Mar 17 Python
从CentOS安装完成到生成词云python的实例
Dec 01 Python
Python(Django)项目与Apache的管理交互的方法
May 16 Python
使用tqdm显示Python代码执行进度功能
Dec 08 Python
python名片管理系统开发
Jun 18 Python
pytorch 实现在测试的时候启用dropout
May 27 Python
Python中super().__init__()测试以及理解
Dec 06 Python
Python使用MapReduce进行简单的销售统计
Apr 22 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
phpmyadmin 常用选项设置详解版
2010/03/07 PHP
php下使用iconv需要注意的问题
2010/11/20 PHP
php自动获取字符串编码函数mb_detect_encoding
2011/05/31 PHP
推荐25款php中非常有用的类库
2014/09/29 PHP
PHP rsa加密解密使用方法
2015/04/27 PHP
PHP的时间戳与具体时间转化的简单实现
2016/06/13 PHP
php利用云片网实现短信验证码功能的示例代码
2017/11/18 PHP
JObj预览一个JS的框架
2008/03/13 Javascript
javascript setTimeout和setInterval计时的区别详解
2013/06/21 Javascript
javascript 控制input只允许输入的各种指定内容
2014/06/19 Javascript
使用AngularJS和PHP的Laravel实现单页评论的方法
2015/06/19 Javascript
JS+CSS实现大气清新的滑动菜单效果代码
2015/10/22 Javascript
Easyui form combobox省市区三级联动
2016/01/13 Javascript
DIV随滚动条滚动而滚动的实现代码【推荐】
2016/04/12 Javascript
jQuery实现所有验证通过方可提交的表单验证
2017/11/21 jQuery
vue.js系列中的vue-fontawesome使用
2018/02/10 Javascript
详解使用element-ui table组件的筛选功能的一个小坑
2018/11/02 Javascript
基于vue实现滚动条滚动到指定位置对应位置数字进行tween特效
2019/04/18 Javascript
详解Vue 匿名、具名和作用域插槽的使用方法
2019/04/22 Javascript
vue 中的动态传参和query传参操作
2020/11/09 Javascript
[46:43]DOTA2上海特级锦标赛D组小组赛#1 EG VS COL第三局
2016/02/28 DOTA
python使用点操作符访问字典(dict)数据的方法
2015/03/16 Python
python中argparse模块用法实例详解
2015/06/03 Python
Scrapy使用的基本流程与实例讲解
2018/10/21 Python
营销与策划个人求职信
2013/09/22 职场文书
预备党员转正考核材料
2014/06/03 职场文书
公司副总经理任命书
2014/06/05 职场文书
求职意向书
2014/07/29 职场文书
2014年平安夜寄语
2014/12/08 职场文书
小学生手册家长意见
2015/06/03 职场文书
小学语文新课改心得体会
2016/01/22 职场文书
股东出资协议书
2016/03/21 职场文书
2019年度行政文员工作计划范本!
2019/07/04 职场文书
生鲜超市—未来中国最具有潜力零售业态
2019/08/02 职场文书
XX部保密工作制度范本
2019/08/27 职场文书
我国拿下天问一号火星着陆区附近 22 个地理实体命名:平乐、西柏坡、古田、漠河等
2022/04/29 数码科技