tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python展示动态规则法用以解决重叠子问题的示例
Apr 02 Python
python连接字符串的方法小结
Jul 13 Python
Python代码解决RenderView窗口not found问题
Aug 28 Python
Python实现获取磁盘剩余空间的2种方法
Jun 07 Python
Numpy掩码式数组详解
Apr 17 Python
Python列表删除元素del、pop()和remove()的区别小结
Sep 11 Python
浅谈python量化 双均线策略(金叉死叉)
Jun 03 Python
20行Python代码实现一款永久免费PDF编辑工具的实现
Aug 27 Python
Python之字典对象的几种创建方法
Sep 30 Python
利用python进行文件操作
Dec 04 Python
Python实现机器学习算法的分类
Jun 03 Python
python 单机五子棋对战游戏
Apr 28 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
如何分别全角和半角以避免乱码
2006/10/09 PHP
php中通过正则表达式下载内容中的远程图片的函数代码
2012/01/10 PHP
关于PHP内存溢出问题的解决方法
2013/06/25 PHP
PHP对文件夹递归执行chmod命令的方法
2015/06/19 PHP
windows下apache搭建php开发环境
2015/08/27 PHP
如何实现浏览器上的右键菜单
2006/07/10 Javascript
Javascript YUI 读码日记之 YAHOO.util.Dom - Part.2 0
2008/03/22 Javascript
js常用排序实现代码
2010/12/28 Javascript
Jquery 表单验证类介绍与实例
2013/06/09 Javascript
JavaScript DOM事件(笔记)
2015/04/08 Javascript
node+express制作爬虫教程
2016/11/11 Javascript
JQuery实现动态操作表格
2017/01/11 Javascript
Vue自定义指令使用方法详解
2017/08/21 Javascript
使用 Node.js 模拟滑动拼图验证码操作的示例代码
2017/11/02 Javascript
koa2 用户注册、登录校验与加盐加密的实现方法
2019/07/22 Javascript
使用layer弹窗提交表单时判断表单是否输入为空的例子
2019/09/26 Javascript
JavaScript 作用域scope简单汇总
2019/10/23 Javascript
云服务器部署Node.js项目的方法步骤(小白系列)
2020/03/23 Javascript
Python检测网站链接是否已存在
2016/04/07 Python
python与C互相调用的方法详解
2017/07/14 Python
Python处理文本换行符实例代码
2018/02/03 Python
Python装饰器的执行过程实例分析
2018/06/04 Python
python+PyQT实现系统桌面时钟
2020/06/16 Python
Python3.5局部变量与全局变量作用域实例分析
2019/04/30 Python
Python 中list ,set,dict的大规模查找效率对比详解
2019/10/11 Python
python实现一个点绕另一个点旋转后的坐标
2019/12/04 Python
Python timeit模块的使用实践
2020/01/13 Python
Python视频编辑库MoviePy的使用
2020/04/01 Python
python Django 反向访问器的外键冲突解决
2020/05/20 Python
python让函数不返回结果的方法
2020/06/22 Python
HTML5 Canvas旋转动画的2个代码例子(一个旋转的太极图效果)
2014/04/10 HTML / CSS
京东奢侈品:全球奢侈品牌
2018/03/17 全球购物
艺术系应届生的自我评价
2013/10/19 职场文书
学前教育求职自荐信范文
2013/12/25 职场文书
如何写通讯稿
2015/07/22 职场文书
Vue OpenLayer测距功能的实现
2022/04/20 Vue.js