tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的多线程实例教程
Aug 27 Python
在Django中创建动态视图的教程
Jul 15 Python
Python基于select实现的socket服务器
Apr 13 Python
Python爬虫之模拟知乎登录的方法教程
May 25 Python
requests和lxml实现爬虫的方法
Jun 11 Python
python 绘制拟合曲线并加指定点标识的实现
Jul 10 Python
基于Python实现大文件分割和命名脚本过程解析
Sep 29 Python
Pytorch中实现只导入部分模型参数的方式
Jan 02 Python
解决import tensorflow as tf 出错的原因
Apr 16 Python
Python实现动态循环输出文字功能
May 07 Python
python输入中文的实例方法
Sep 14 Python
Python使用pandas导入csv文件内容的示例代码
Dec 24 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
php基础知识:类与对象(1)
2006/12/13 PHP
坏狼的PHP学习教程之第1天
2008/06/15 PHP
Linux环境下搭建php开发环境的操作步骤
2013/06/17 PHP
php根据数据id自动生成编号的实现方法
2016/10/16 PHP
asp.net HttpHandler实现图片防盗链
2009/11/09 Javascript
jQuery EasyUI API 中文文档 - MenuButton菜单按钮使用介绍
2011/10/06 Javascript
CSS鼠标响应事件经过、移动、点击示例介绍
2013/09/04 Javascript
Bootstrap每天必学之表单
2015/11/23 Javascript
JavaScript使用DeviceOne开发实战(二) 生成调试安装包
2015/12/01 Javascript
使用PBFunc在Powerbuilder中支付宝当面付款功能
2016/10/01 Javascript
js实现随机数字字母验证码
2017/06/19 Javascript
vue 请求后台数据的实例代码
2017/06/22 Javascript
jQuery实现全选、反选和不选功能
2017/08/16 jQuery
iview中Select 选择器多选校验方法
2018/03/15 Javascript
实例详解vue.js浅度监听和深度监听及watch用法
2018/08/16 Javascript
jQuery实现的老虎机跑动效果示例
2018/12/29 jQuery
VUE实现密码验证与提示功能
2019/10/18 Javascript
js抽奖转盘实现方法分析
2020/05/16 Javascript
js实现跳一跳小游戏
2020/07/31 Javascript
Vue中使用JsonView来展示Json树的实例代码
2020/11/16 Javascript
python时间整形转标准格式的示例分享
2014/02/14 Python
简单介绍Python中的try和finally和with方法
2015/05/05 Python
Python多进程并发(multiprocessing)用法实例详解
2015/06/02 Python
Python自动调用IE打开某个网站的方法
2015/06/03 Python
python 返回列表中某个值的索引方法
2018/11/07 Python
Python实现的批量修改文件后缀名操作示例
2018/12/07 Python
对python指数、幂数拟合curve_fit详解
2018/12/29 Python
python之MSE、MAE、RMSE的使用
2020/02/24 Python
Stutterheim瑞典:瑞典高级外套时装品牌
2019/06/24 全球购物
幼儿教师研修感言
2014/02/12 职场文书
大班亲子运动会方案
2014/06/10 职场文书
对照检查剖析材料
2014/09/30 职场文书
职称评定个人总结
2015/03/05 职场文书
《亲亲我的妈妈》观后感(3篇)
2019/09/26 职场文书
浅谈MySQL user权限表
2021/06/18 MySQL
JavaScript分页组件使用方法详解
2021/07/26 Javascript