tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python判断字符串编码的简单实现方法(使用chardet)
Jul 01 Python
利用Python2下载单张图片与爬取网页图片实例代码
Dec 25 Python
Python实现学校管理系统
Jan 11 Python
对Python 网络设备巡检脚本的实例讲解
Apr 22 Python
Python处理命令行参数模块optpars用法实例分析
May 31 Python
树莓派使用python-librtmp实现rtmp推流h264的方法
Jul 22 Python
python解释器spython使用及原理解析
Aug 24 Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 Python
Python递归及尾递归优化操作实例分析
Feb 01 Python
django中嵌套的try-except实例
May 21 Python
Python实现文件压缩和解压的示例代码
Aug 12 Python
Pytest之测试命名规则的使用
Apr 16 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
第四节--构造函数和析构函数
2006/11/16 PHP
php中计算时间差的几种方法
2009/12/31 PHP
php读取xml实例代码
2010/01/28 PHP
PHP实现搜索相似图片
2015/09/22 PHP
ThinkPHP中数据操作案例分析
2015/09/27 PHP
php自定义函数转换html标签示例
2016/09/29 PHP
PHP防止图片盗用(盗链)的方法小结
2016/11/11 PHP
Javascript 表单之间的数据传递代码
2008/12/04 Javascript
jQuery遍历Form示例代码
2013/09/03 Javascript
thinkphp中常用的系统常量和系统变量
2014/03/05 Javascript
JSON+Jquery省市区三级联动
2016/01/13 Javascript
bootstrap和jQuery.Gantt的css冲突 如何解决
2016/05/29 Javascript
JavaScript的instanceof运算符学习教程
2016/06/08 Javascript
JS图片压缩(pc端和移动端都适用)
2017/01/12 Javascript
详谈Ajax请求中的async:false/true的作用(ajax 在外部调用问题)
2017/02/10 Javascript
原生JS上传大文件显示进度条 php上传文件代码
2020/03/27 Javascript
利用vue-i18n实现多语言切换效果的方法
2019/06/19 Javascript
vue父子模板传值问题解决方法案例分析
2020/02/26 Javascript
Ant Design Vue table中列超长显示...并加提示语的实例
2020/10/31 Javascript
Python通过PIL获取图片主要颜色并和颜色库进行对比的方法
2015/03/19 Python
Python的Urllib库的基本使用教程
2015/04/30 Python
Python实现简单查找最长子串功能示例
2019/02/26 Python
numpy下的flatten()函数用法详解
2019/05/27 Python
Python 用三行代码提取PDF表格数据
2019/10/13 Python
python 串行执行和并行执行实例
2020/04/30 Python
python如何随机生成高强度密码
2020/08/19 Python
通过实例解析Python文件操作实现步骤
2020/09/21 Python
Selenium Webdriver元素定位的八种常用方式(小结)
2021/01/13 Python
matplotlib绘制正余弦曲线图的实现
2021/02/22 Python
Yankee Candle官网:美国最畅销蜡烛品牌之一
2020/01/05 全球购物
计算机大学生的自我评价
2013/10/15 职场文书
八项规定整改措施
2014/02/12 职场文书
最美孝心少年事迹材料
2014/08/15 职场文书
继承权公证书范本
2015/01/23 职场文书
倡议书怎么写?
2019/04/11 职场文书
导游词之四川熊猫基地
2020/01/13 职场文书