tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python对Excel进行读写操作
Mar 30 Python
spyder常用快捷键(分享)
Jul 19 Python
python数据结构链表之单向链表(实例讲解)
Jul 25 Python
Scrapy框架CrawlSpiders的介绍以及使用详解
Nov 29 Python
python 获取url中的参数列表实例
Dec 18 Python
python networkx 包绘制复杂网络关系图的实现
Jul 10 Python
Django缓存系统实现过程解析
Aug 02 Python
python with (as)语句实例详解
Feb 04 Python
使用tensorflow实现VGG网络,训练mnist数据集方式
May 26 Python
8种常用的Python工具
Aug 05 Python
python中如何使用虚拟环境
Oct 14 Python
Python Tkinter实例——模拟掷骰子
Oct 24 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
生成静态页面的PHP类
2006/07/15 PHP
php 随机数的产生、页面跳转、件读写、文件重命名、switch语句
2009/08/07 PHP
php中的观察者模式
2010/03/24 PHP
PHP读取CURL模拟登录时生成Cookie文件的方法
2014/11/04 PHP
php的api数据接口书写实例(推荐)
2016/09/22 PHP
PHP定义字符串的四种方式详解
2018/02/06 PHP
实例讲解PHP表单验证功能
2019/02/15 PHP
JavaScript中判断函数是new还是()调用的区别说明
2011/04/07 Javascript
jQuery学习笔记 操作jQuery对象 文档处理
2012/09/19 Javascript
jquery的$getjson调用并获取远程的JSON字符串问题
2012/12/10 Javascript
如何用jquery控制表格奇偶行及活动行颜色
2014/04/20 Javascript
纯js实现重发验证码按钮倒数功能
2015/04/21 Javascript
jquery实现表单输入时提示文字滑动向上效果
2015/08/10 Javascript
JavaScript 模块的循环加载实现方法
2015/12/13 Javascript
JS/jQ实现免费获取手机验证码倒计时效果
2016/06/13 Javascript
JS中关于事件处理函数名后面是否带括号的问题
2016/11/16 Javascript
基于vue的下拉刷新指令和滚动刷新指令
2016/12/23 Javascript
vue.js框架实现表单排序和分页效果
2017/08/09 Javascript
详解在vue-test-utils中mock全局对象
2018/11/07 Javascript
angular中如何绑定iframe中src的方法
2019/02/01 Javascript
Vue自定义多选组件使用详解
2020/09/08 Javascript
npm ci命令的基本使用方法
2020/09/20 Javascript
python 切片和range()用法说明
2013/03/24 Python
python计算最大优先级队列实例
2013/12/18 Python
python获取局域网占带宽最大3个ip的方法
2015/07/09 Python
Python中集合的内建函数和内建方法学习教程
2015/08/19 Python
Python中在for循环中嵌套使用if和else语句的技巧
2016/06/20 Python
使用python对excle和json互相转换的示例
2018/10/23 Python
python3用PIL把图片转换为RGB图片的实例
2019/07/04 Python
Python执行时间的几种计算方法
2020/07/31 Python
马来西亚最大的在线隐形眼镜商店:MrLens
2019/03/27 全球购物
土耳其玩具商店:Toyzz Shop
2019/08/02 全球购物
女大学生个人求职信
2013/12/09 职场文书
村主任“四风”问题个人对照检查材料思想汇报
2014/10/02 职场文书
专业技术人员年度考核评语
2014/12/31 职场文书
HTML中table表格拆分合并(colspan、rowspan)
2021/04/07 HTML / CSS