python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程学习笔记(七):HTML和XHTML解析(HTMLParser、BeautifulSoup)
Jun 09 Python
Python调用C语言开发的共享库方法实例
Mar 18 Python
Python下的twisted框架入门指引
Apr 15 Python
用TensorFlow实现lasso回归和岭回归算法的示例
May 02 Python
Python3爬虫爬取英雄联盟高清桌面壁纸功能示例【基于Scrapy框架】
Dec 05 Python
python 获取毫秒数,计算调用时长的方法
Feb 20 Python
python dlib人脸识别代码实例
Apr 04 Python
python basemap 画出经纬度并标定的实例
Jul 09 Python
Python使用mongodb保存爬取豆瓣电影的数据过程解析
Aug 14 Python
Python使用matplotlib绘制圆形代码实例
May 27 Python
python递归函数用法详解
Oct 26 Python
python tkinter实现定时关机
Apr 21 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
PHP中单引号与双引号的区别分析
2014/08/19 PHP
PHP实现原生态图片上传封装类方法
2016/11/08 PHP
PHP文件与目录操作示例
2016/12/24 PHP
PHP7扩展开发之hello word实现方法详解
2018/01/15 PHP
js返回前一页刷新本页重载页面
2014/07/29 Javascript
JS实现清除指定cookies的方法
2014/09/20 Javascript
javascript简单实现图片预加载
2014/12/03 Javascript
JavaScript判断一个字符串是否包含指定子字符串的方法
2015/03/18 Javascript
JS实现仿google、百度搜索框输入信息智能提示的实现方法
2015/04/20 Javascript
Bootstrap编写一个兼容主流浏览器的受众门户式风格页面
2016/07/01 Javascript
微信小程序开发之大转盘 仿天猫超市抽奖实例
2016/12/08 Javascript
详解react-router 4.0 下服务器如何配合BrowserRouter
2017/12/29 Javascript
详解微信小程序实现WebSocket心跳重连
2018/07/31 Javascript
JavaScript中变量、指针和引用功能与操作示例
2018/08/04 Javascript
浅谈在不使用ssr的情况下解决Vue单页面SEO问题(2)
2018/11/08 Javascript
js中arguments对象的深入理解
2019/05/14 Javascript
在Vue.js中使用TypeScript的方法
2020/03/19 Javascript
Vue使用轮询定时发送请求代码
2020/08/10 Javascript
[02:40]2014DOTA2 国际邀请赛中国区预选赛 四大豪门抵达华西村
2014/05/23 DOTA
Python调用SQLPlus来操作和解析Oracle数据库的方法
2016/04/09 Python
利用numpy+matplotlib绘图的基本操作教程
2017/05/03 Python
python3实现TCP协议的简单服务器和客户端案例(分享)
2017/06/14 Python
基于Pandas读取csv文件Error的总结
2018/06/15 Python
Python自动生成代码 使用tkinter图形化操作并生成代码框架
2019/09/18 Python
Python 取numpy数组的某几行某几列方法
2019/10/24 Python
Python使用configparser库读取配置文件
2020/02/22 Python
Fashion Eyewear美国:英国线上设计师眼镜和太阳镜的零售商
2016/08/15 全球购物
美国著名手表网站:Timepiece
2017/11/15 全球购物
Annoushka英国官网:英国奢侈珠宝品牌
2018/10/20 全球购物
戴森西班牙官网:Dyson西班牙
2020/02/04 全球购物
离职保密承诺书
2014/05/28 职场文书
2014年学校财务工作总结
2014/12/06 职场文书
驳回起诉民事裁定书
2015/05/19 职场文书
女儿满月酒致辞
2015/07/29 职场文书
血轮眼轮回眼特效 html+css
2021/03/31 HTML / CSS
MySQL优化之如何写出高质量sql语句
2021/05/17 MySQL