python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用SQLite的简单教程
Apr 29 Python
用Python一键搭建Http服务器的方法
Jun 01 Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 Python
详解Python:面向对象编程
Apr 10 Python
Python切片操作去除字符串首尾的空格
Apr 22 Python
Python实现爬取亚马逊数据并打印出Excel文件操作示例
May 16 Python
ipad上运行python的方法步骤
Oct 12 Python
python 两个数据库postgresql对比
Oct 21 Python
图解python全局变量与局部变量相关知识
Nov 02 Python
Django使用Celery加redis执行异步任务的实例内容
Feb 20 Python
python如何使用代码运行助手
Jul 03 Python
pandas使用函数批量处理数据(map、apply、applymap)
Nov 27 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
php执行sql语句的写法
2009/03/10 PHP
PHP 采集程序中常用的函数
2009/12/09 PHP
php判断正常访问和外部访问的示例
2014/02/10 PHP
php 指定范围内多个随机数代码实例
2016/07/18 PHP
关于 Laravel Redis 多个进程同时取队列问题详解
2017/12/25 PHP
父子窗体间传递JSON格式的数据的代码
2010/12/25 Javascript
深入了解Node.js中的一些特性
2014/09/25 Javascript
jquery中append()与appendto()用法分析
2014/11/14 Javascript
D3.js中data(), enter() 和 exit()的问题详解
2015/08/17 Javascript
js实现简洁的滑动门菜单(选项卡)效果代码
2015/09/04 Javascript
Angular.js 实现数字转换汉字实例代码
2016/07/14 Javascript
JS无缝滚动效果实现方法分析
2016/12/21 Javascript
angularJS 发起$http.post和$http.get请求的实现方法
2017/05/18 Javascript
jQuery插件select2利用ajax高效查询大数据列表(可搜索、可分页)
2017/05/19 jQuery
浅谈Vue组件及组件的注册方法
2018/08/24 Javascript
vue2中,根据list的id进入对应的详情页并修改title方法
2018/08/24 Javascript
vue中接口域名配置为全局变量的实现方法
2018/09/20 Javascript
浅谈VUE-CLI脚手架热更新太慢的原因和解决方法
2018/09/28 Javascript
详解vue-property-decorator使用手册
2019/07/29 Javascript
Python写的英文字符大小写转换代码示例
2015/03/06 Python
matplotlib绘图实例演示标记路径
2018/01/23 Python
Python实现对一个函数应用多个装饰器的方法示例
2018/02/09 Python
解析python实现Lasso回归
2019/09/11 Python
关于Flask项目无法使用公网IP访问的解决方式
2019/11/19 Python
jupyter notebook插入本地图片的实现
2020/04/13 Python
浅析Python requests 模块
2020/10/09 Python
意大利辅助药品、药物和补品在线销售:FarmaEurope
2020/04/29 全球购物
销售文员岗位职责
2013/11/29 职场文书
班主任工作经验材料
2014/02/02 职场文书
九年级数学教学反思
2014/02/02 职场文书
机械设计及其自动化专业求职信
2014/06/09 职场文书
学校法制宣传月活动总结
2014/07/03 职场文书
募捐感谢信
2015/01/22 职场文书
工厂仓管员岗位职责
2015/04/01 职场文书
SpringCloud项目如何解决log4j2漏洞
2022/04/10 Java/Android
nginx sticky实现基于cookie负载均衡示例详解
2022/12/24 Servers