python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python判断文件和文件夹是否存在的方法
May 21 Python
在Django中创建第一个静态视图
Jul 15 Python
python中OrderedDict的使用方法详解
May 05 Python
python之matplotlib学习绘制动态更新图实例代码
Jan 23 Python
使用11行Python代码盗取了室友的U盘内容
Oct 23 Python
Scrapy框架爬取Boss直聘网Python职位信息的源码
Feb 22 Python
局域网内python socket实现windows与linux间的消息传送
Apr 19 Python
Python API 自动化实战详解(纯代码)
Jun 11 Python
python主线程与子线程的结束顺序实例解析
Dec 17 Python
tensorflow模型转ncnn的操作方式
May 25 Python
Python Selenium实现无可视化界面过程解析
Aug 25 Python
python实现股票历史数据可视化分析案例
Jun 10 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
PHP 日志缩略名的创建函数代码
2010/05/26 PHP
php中设置index.php文件为只读的方法
2013/02/06 PHP
php常用Stream函数集介绍
2013/06/24 PHP
学习php开源项目的源码指南
2014/12/21 PHP
php编写的抽奖程序中奖概率算法
2015/05/14 PHP
php使用CURL不依赖COOKIEJAR获取COOKIE的方法
2015/06/17 PHP
javascript 动态参数判空操作
2008/12/22 Javascript
jquery json 实例代码
2010/12/02 Javascript
js使浏览器窗口最大化实现代码(适用于IE)
2013/08/07 Javascript
简单的两种Extjs formpanel加载数据的方式
2013/11/09 Javascript
在JavaScript中处理时间之getHours()方法的使用
2015/06/10 Javascript
基于javascript实现样式清新图片轮播特效
2016/03/30 Javascript
js导出excel文件的简洁方法(推荐)
2016/11/02 Javascript
JS实现DIV高度自适应窗口示例
2017/02/16 Javascript
详解vuelidate 对于vueJs2.0的验证解决方案
2017/03/09 Javascript
angularjs 获取默认选中的单选按钮的value方法
2018/02/28 Javascript
基于vue-cli vue-router搭建底部导航栏移动前端项目
2018/02/28 Javascript
学习JS中的DOM节点以及操作
2018/04/30 Javascript
node中的session的具体使用
2018/09/14 Javascript
解决vue动态为数据添加新属性遇到的问题
2018/09/18 Javascript
基于python脚本实现软件的注册功能(机器码+注册码机制)
2016/10/09 Python
python基础练习之几个简单的游戏
2017/11/10 Python
Pycharm 操作Django Model的简单运用方法
2018/05/23 Python
解决jupyter notebook 出现In[*]的问题
2020/04/13 Python
在pycharm中创建django项目的示例代码
2020/05/28 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
2020/06/28 Python
详解Html5页面实现下载文件(apk、txt等)的三种方式
2018/10/22 HTML / CSS
SportsDirect.com新加坡:英国第一体育零售商
2019/03/30 全球购物
莱德杯高尔夫欧洲官方商店:Ryder Cup Shop
2019/08/14 全球购物
中专生毕业个人鉴定
2014/02/26 职场文书
眼镜促销方案
2014/03/15 职场文书
2014教师党员个人自我评议
2014/09/20 职场文书
2014年科普工作总结
2014/12/06 职场文书
预备党员转正意见
2015/06/01 职场文书
Python-typing: 类型标注与支持 Any类型详解
2021/05/10 Python
Spring Cloud Netflix 套件中的负载均衡组件 Ribbon
2022/04/13 Java/Android