tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python高手之路python处理excel文件(方法汇总)
Jan 07 Python
Python将8位的图片转为24位的图片实现方法
Oct 24 Python
python生成以及打开json、csv和txt文件的实例
Nov 16 Python
Python图像处理之图片文字识别功能(OCR)
Jul 30 Python
python是否适合网页编程详解
Oct 04 Python
使用Python完成15位18位身份证的互转功能
Nov 06 Python
python pprint模块中print()和pprint()两者的区别
Feb 10 Python
python json.dumps中文乱码问题解决
Apr 01 Python
详解Python 循环嵌套
Jul 09 Python
python的launcher用法知识点总结
Aug 07 Python
详解python爬取弹幕与数据分析
Nov 14 Python
python中判断数字是否为质数的实例讲解
Dec 06 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
php数组对百万数据进行排除重复数据的实现代码
2010/06/08 PHP
php中print(),print_r(),echo()的区别详解
2014/12/01 PHP
如何修改yii2.0自带的user表为其它的表
2017/08/01 PHP
javascript动画效果类封装代码
2007/08/28 Javascript
Javascript中的var_dump函数实现代码
2009/09/07 Javascript
jQuery ajax在GBK编码下表单提交终极解决方案(非二次编码方法)
2010/10/20 Javascript
jQuery语法总结和注意事项小结
2012/11/11 Javascript
jQuery实现点击该行即可删除HTML表格行
2014/10/17 Javascript
jQuery中判断对象是否存在的方法汇总
2016/02/24 Javascript
JavaScript常用本地对象小结
2016/03/28 Javascript
JavaScript在form表单中使用button按钮实现submit提交方法
2017/01/23 Javascript
canvas实现探照灯效果
2017/02/07 Javascript
JavaScript算法教程之sku(库存量单位)详解
2017/06/29 Javascript
JavaScript实现简单的文本逐字打印效果示例
2018/04/12 Javascript
vue项目中jsonp跨域获取qq音乐首页推荐问题
2018/05/30 Javascript
Angular PWA使用的Demo示例
2019/01/31 Javascript
javascript中join方法实例讲解
2019/02/21 Javascript
js回调函数仿360开机
2019/12/26 Javascript
[47:43]完美世界DOTA2联赛PWL S3 Magama vs GXR 第二场 12.19
2020/12/24 DOTA
简述Python中的进程、线程、协程
2016/03/18 Python
在win和Linux系统中python命令行运行的不同
2016/07/03 Python
Python cookbook(数据结构与算法)从序列中移除重复项且保持元素间顺序不变的方法
2018/03/13 Python
使用python语言,比较两个字符串是否相同的实例
2018/06/29 Python
很酷的python表白工具 你喜欢我吗
2019/04/11 Python
Python跳出多重循环的方法示例
2019/07/03 Python
如何在python中实现随机选择
2019/11/02 Python
python批量处理txt文件的实例代码
2020/01/13 Python
Python进程间通信multiprocess代码实例
2020/03/18 Python
使用Python制作一盏 3D 花灯喜迎元宵佳节
2021/02/26 Python
解决PDF 转图片时丢文字的一种可能方式
2021/03/04 Python
CSS3 Flexbox中flex-shrink属性的用法示例介绍
2013/12/30 HTML / CSS
详解淘宝H5 sign加密算法
2020/08/25 HTML / CSS
经济贸易系毕业生求职信
2014/05/31 职场文书
法人委托书的范本格式
2014/09/11 职场文书
python 办公自动化——基于pyqt5和openpyxl统计符合要求的名单
2021/05/25 Python
vue使用wavesurfer.js解决音频可视化播放问题
2022/04/04 Vue.js