tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python threading模块操作多线程介绍
Apr 08 Python
Python判断文本中消息重复次数的方法
Apr 27 Python
人脸识别经典算法一 特征脸方法(Eigenface)
Mar 13 Python
python读取文本中的坐标方法
Oct 14 Python
windows下 兼容Python2和Python3的解决方法
Dec 05 Python
python将txt文件读取为字典的示例
Dec 22 Python
用python 实现在不确定行数情况下多行输入方法
Jan 28 Python
python cv2读取rtsp实时码流按时生成连续视频文件方式
Dec 25 Python
python 基于wx实现音乐播放
Nov 24 Python
celery在python爬虫中定时操作实例讲解
Nov 27 Python
pycharm如何设置官方中文(如何汉化)
Dec 29 Python
图文详解matlab原始处理图像几何变换
Jul 09 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
SWFUpload与CI不能正确上传识别文件MIME类型解决方法分享
2011/04/18 PHP
PHP file_get_contents函数读取远程数据超时的解决方法
2015/05/13 PHP
php实现表单多按钮提交action的处理方法
2015/10/24 PHP
thinkphp配置文件路径的实现方法
2016/08/30 PHP
CheckBox 如何实现全选?
2006/06/23 Javascript
js获取当前select 元素值的代码
2010/04/19 Javascript
图片上传判断及预览脚本的效果实例
2013/08/07 Javascript
js实现固定显示区域内自动缩放图片的方法
2015/07/18 Javascript
vue组件学习教程
2017/09/09 Javascript
详解node nvm进行node多版本管理
2017/10/21 Javascript
从零开始封装自己的自定义Vue组件
2018/10/09 Javascript
使用koa2创建web项目的方法步骤
2019/03/12 Javascript
jquery 键盘事件 keypress() keydown() keyup()用法总结
2019/10/23 jQuery
浅谈JavaScript节流和防抖函数
2020/08/25 Javascript
vuex刷新后数据丢失的解决方法
2020/10/18 Javascript
[01:03:47]VP vs NewBee Supermajor 胜者组 BO3 第一场 6.5
2018/06/06 DOTA
Python魔术方法详解
2015/02/14 Python
python使用folium库绘制地图点击框
2018/09/21 Python
PyQt5 QListWidget选择多项并返回的实例
2019/06/17 Python
如何基于python生成list的所有的子集
2019/11/11 Python
python变量的作用域是什么
2020/05/26 Python
Python新手如何进行闭包时绑定变量操作
2020/05/29 Python
CSS3 创建网页动画实现弹跳球动效果
2018/10/30 HTML / CSS
网络维护中文求职信
2014/01/03 职场文书
简历自我评价怎么写呢?
2014/01/06 职场文书
十月份红领巾广播稿
2014/01/22 职场文书
社区党总支书记先进事迹材料
2014/01/24 职场文书
公证委托书模板
2014/04/03 职场文书
我爱我家教学反思
2014/05/01 职场文书
2014年公务员退休工资改革方案
2014/10/01 职场文书
四风查摆问题自查报告
2014/10/10 职场文书
党员三严三实对照检查材料
2014/10/13 职场文书
房地产置业顾问岗位职责
2015/04/11 职场文书
李强为自己工作观后感
2015/06/11 职场文书
CSS3通过var()和calc()函数实现动画特效
2021/03/30 HTML / CSS
升级 Win11 还是坚守 Win10?微软 Win11 新系统缺失功能大盘点
2022/04/05 数码科技