python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
寻找网站后台地址的python脚本
Sep 01 Python
Python中内建函数的简单用法说明
May 05 Python
Python实现Windows和Linux之间互相传输文件(文件夹)的方法
May 08 Python
Python使用 Beanstalkd 做异步任务处理的方法
Apr 24 Python
Python实现的微信好友数据分析功能示例
Jun 21 Python
python找出完数的方法
Nov 12 Python
浅谈Python接口对json串的处理方法
Dec 19 Python
在Pycharm terminal中字体大小设置的方法
Jan 16 Python
Python判断telnet通不通的实例
Jan 26 Python
python爬虫学习笔记之pyquery模块基本用法详解
Apr 09 Python
python进行参数传递的方法
May 12 Python
Pycharm操作Git及GitHub的步骤详解
Oct 27 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
PHP制作3D扇形统计图以及对图片进行缩放操作实例
2014/10/23 PHP
php采用ajax数据提交post与post常见方法总结
2014/11/10 PHP
php项目开发中用到的快速排序算法分析
2016/06/25 PHP
Thinkphp框架中D方法与M方法的区别
2016/12/23 PHP
jquery按回车提交数据的代码示例
2013/11/05 Javascript
javascript事件冒泡详解和捕获、阻止方法
2014/04/12 Javascript
JavaScript获取路径设计源码
2014/05/22 Javascript
JavaScript Serializer序列化时间处理示例
2014/07/31 Javascript
javascript html5 canvas实现可拖动省份的中国地图
2016/03/11 Javascript
JavaScript 弹出子窗体并返回结果到父窗体的实现代码
2016/05/28 Javascript
Javascript实现图片加载从模糊到清晰显示的方法
2016/06/21 Javascript
引用jquery框架后出错的解决方法
2016/08/09 Javascript
vue2.0 keep-alive最佳实践
2017/07/06 Javascript
jQuery完成表单验证的实例代码(纯代码)
2017/09/30 jQuery
javascript实现Emrips反质数枚举的示例代码
2017/12/06 Javascript
angular2 ng2-file-upload上传示例代码
2018/08/23 Javascript
node.js中express模块创建服务器和http模块客户端发请求
2019/03/06 Javascript
vue实现滑动到底部加载更多效果
2020/10/27 Javascript
jQuery单页面文字搜索插件jquery.fullsearch.js的使用方法
2020/02/04 jQuery
[03:17]史诗级大片应援2018DOTA2国际邀请赛 致敬每一位坚守遗迹的勇士
2018/07/20 DOTA
[54:57]DOTA2-DPC中国联赛定级赛 Aster vs DLG BO3第二场 1月8日
2021/03/11 DOTA
Python 获取当前所在目录的方法详解
2017/08/02 Python
Python 统计字数的思路详解
2018/05/08 Python
Python把csv数据写入list和字典类型的变量脚本方法
2018/06/15 Python
Pycharm+Python+PyQt5使用详解
2019/09/25 Python
Python3中configparser模块读写ini文件并解析配置的用法详解
2020/02/18 Python
浅谈Python中文件夹和python package包的区别
2020/06/01 Python
H5新属性audio音频和video视频的控制详解(推荐)
2016/12/09 HTML / CSS
数字漫画:comiXology
2020/06/13 全球购物
药品采购员岗位职责
2014/02/08 职场文书
幼儿老师求职信
2014/06/30 职场文书
国际政治学专业推荐信
2014/09/26 职场文书
2014年大学班级工作总结
2014/11/14 职场文书
机关干部纪律作风整顿心得体会
2016/01/23 职场文书
MySQL运行报错:“Expression #1 of SELECT list is not in GROUP BY clause and contains nonaggre”解决方法
2022/06/14 MySQL
Redis+AOP+自定义注解实现限流
2022/06/28 Redis