tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 列表(List)操作方法详解
Mar 11 Python
python实现进程间通信简单实例
Jul 23 Python
深入理解python中的select模块
Apr 23 Python
python的继承知识点总结
Dec 10 Python
python实现基于信息增益的决策树归纳
Dec 18 Python
Python文件操作中进行字符串替换的方法(保存到新文件/当前文件)
Jun 28 Python
python实现的爬取电影下载链接功能示例
Aug 26 Python
python获取Linux发行版名称
Aug 30 Python
详解python中docx库的安装过程
Nov 08 Python
Python中zip函数如何使用
Jun 04 Python
Python logging模块异步线程写日志实现过程解析
Jun 30 Python
Elasticsearch 聚合查询和排序
Apr 19 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
require(),include(),require_once()和include_once()的异同
2007/01/02 PHP
php在线生成ico文件的代码
2007/10/09 PHP
探讨GDFONTPATH能否被winxp下的php支持
2013/06/21 PHP
网页打开自动最大化的js代码
2012/08/22 Javascript
Javascript中自动切换焦点实现代码
2012/12/15 Javascript
JavaScript设置body高度为浏览器高度的方法
2015/02/09 Javascript
基于JS实现密码框(password)中显示文字提示功能代码
2016/05/27 Javascript
js调用父框架函数与弹窗调用父页面函数的简单方法
2016/11/01 Javascript
关于Vue.js 2.0的Vuex 2.0 你需要更新的知识库
2016/11/30 Javascript
JS实现颜色梯度与渐变效果完整实例
2016/12/30 Javascript
JavaScript实现删除数组重复元素的5种常用高效算法总结
2018/01/18 Javascript
vuejs前后端数据交互之从后端请求数据的实例
2018/08/11 Javascript
JS实现利用闭包判断Dom元素和滚动条的方向示例
2019/08/26 Javascript
JS+Canvas实现五子棋游戏
2020/08/26 Javascript
Python自动化测试ConfigParser模块读写配置文件
2016/08/15 Python
python数据结构之链表的实例讲解
2017/07/25 Python
Python实现的井字棋(Tic Tac Toe)游戏示例
2018/01/31 Python
利用TensorFlow训练简单的二分类神经网络模型的方法
2018/03/05 Python
python将类似json的数据存储到MySQL中的实例
2019/07/12 Python
对python中不同模块(函数、类、变量)的调用详解
2019/07/16 Python
Python学习笔记之Zip和Enumerate用法实例分析
2019/08/14 Python
Pycharm debug调试时带参数过程解析
2020/02/03 Python
如何解决pycharm调试报错的问题
2020/08/06 Python
adidas官方旗舰店:德国运动用品制造商
2017/11/25 全球购物
应届生妇产科护士求职信
2013/10/27 职场文书
优秀学生获奖感言
2014/02/15 职场文书
大学生两会精神学习心得体会
2014/03/10 职场文书
初三开学计划书
2014/04/27 职场文书
个人租房协议书
2014/11/28 职场文书
师德师风事迹材料
2014/12/20 职场文书
教师工作决心书
2015/02/04 职场文书
社区五一劳动节活动总结
2015/02/09 职场文书
昆虫记读书笔记
2015/06/26 职场文书
导游词之阆中古城
2019/12/23 职场文书
一文搞懂php的垃圾回收机制
2021/06/18 PHP
python 标准库原理与用法详解之os.path篇
2021/10/24 Python