tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现Windows上气泡提醒效果的方法
Jun 03 Python
详解Python中的日志模块logging
Jun 19 Python
解决Matplotlib图表不能在Pycharm中显示的问题
May 24 Python
python3实现字符串的全排列的方法(无重复字符)
Jul 07 Python
python3.4 将16进制转成字符串的实例
Jun 12 Python
TensorFlow车牌识别完整版代码(含车牌数据集)
Aug 05 Python
Python 字符串、列表、元组的截取与切片操作示例
Sep 17 Python
如何定义TensorFlow输入节点
Jan 23 Python
基于Python数据分析之pandas统计分析
Mar 03 Python
Windows下PyCharm配置Anaconda环境(超详细教程)
Jul 31 Python
一文搞懂Python Sklearn库使用
Aug 23 Python
Python批量解压&压缩文件夹的示例代码
Apr 04 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
将博客园(cnblogs.com)数据导入到wordpress的代码
2013/01/06 PHP
php  PATH_SEPARATOR判断当前服务器系统类型实例
2016/10/28 PHP
php引用和拷贝的区别知识点总结
2019/09/23 PHP
PHP框架实现WebSocket在线聊天通讯系统
2019/11/21 PHP
yii框架结合charjs统计上一年与当前年数据的方法示例
2020/04/04 PHP
PhpStorm2020 + phpstudyV8 +XDebug的教程详解
2020/09/17 PHP
laravel7学习之无限级分类的最新实现方法
2020/09/30 PHP
Javascript中的相等与不等运算
2010/04/25 Javascript
AngularJS基础知识
2014/12/21 Javascript
jQuery中insertBefore()方法用法实例
2015/01/08 Javascript
Javascript通过overflow控制列表闭合与展开的方法
2015/05/15 Javascript
ES6学习笔记之Set和Map数据结构详解
2017/04/07 Javascript
node puppeteer(headless chrome)实现网站登录
2018/05/09 Javascript
vue集成chart.js的实现方法
2019/08/20 Javascript
Javascript作用域和作用域链原理解析
2020/03/03 Javascript
Vue2.0 $set()的正确使用详解
2020/07/28 Javascript
一文秒懂nodejs中的异步编程
2021/01/28 NodeJs
python判断windows隐藏文件的方法
2014/03/21 Python
python进阶教程之文本文件的读取和写入
2014/08/29 Python
python实现的守护进程(Daemon)用法实例
2015/06/02 Python
python paramiko模块学习分享
2017/08/23 Python
Django admin美化插件suit使用示例
2017/12/12 Python
python3基于TCP实现CS架构文件传输
2018/07/28 Python
python交互界面的退出方法
2019/02/16 Python
python实现身份证实名认证的方法实例
2019/11/08 Python
django 框架实现的用户注册、登录、退出功能示例
2019/11/28 Python
TensorFLow 变量命名空间实例
2020/02/11 Python
Cole Haan官方网站:美国时尚潮流品牌
2017/12/06 全球购物
澳大利亚领先的时尚内衣零售商:Bras N Things
2020/07/28 全球购物
what is the difference between ext2 and ext3
2015/08/25 面试题
保护环境的建议书
2014/03/12 职场文书
三八红旗集体先进事迹材料
2014/05/22 职场文书
个人授权委托书格式
2014/08/30 职场文书
2014年环卫工作总结
2014/11/22 职场文书
西安兵马俑导游词
2015/02/02 职场文书
合作意向书范本
2019/04/17 职场文书