tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现给字典添加条目的方法
Sep 25 Python
python实现下载整个ftp目录的方法
Jan 17 Python
用python写一个windows下的定时关机脚本(推荐)
Mar 21 Python
基于scrapy的redis安装和配置方法
Jun 13 Python
python3.6使用pickle序列化class的方法
Oct 22 Python
Python批处理更改文件名os.rename的方法
Oct 26 Python
Python定义函数功能与用法实例详解
Apr 08 Python
python常用数据重复项处理方法
Nov 22 Python
Python图像处理库PIL的ImageGrab模块介绍详解
Feb 26 Python
Python 在 VSCode 中使用 IPython Kernel 的方法详解
Sep 05 Python
解决pytorch 保存模型遇到的问题
Mar 03 Python
python如何做代码性能分析
Apr 26 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
thinkphp的CURD和查询方式介绍
2013/12/19 PHP
PHP把JPEG图片转换成Progressive JPEG的方法
2014/06/30 PHP
2个比较经典的PHP加密解密函数分享
2014/07/01 PHP
PHP贪婪算法解决0-1背包问题实例分析
2015/03/23 PHP
在WordPress中使用wp-cron插件来设置定时任务
2015/12/10 PHP
ThinkPHP静态缓存简单配置和使用方法详解
2016/03/23 PHP
JSON字符串传到后台PHP处理问题的解决方法
2016/06/05 PHP
自定义一个jquery插件[鼠标悬浮时候 出现说明label]
2011/06/27 Javascript
javascipt匹配单行和多行注释的正则表达式
2013/11/20 Javascript
淘宝网提供的国内NPM镜像简介和使用方法
2014/04/17 Javascript
教你用jquery实现iframe自适应高度
2014/06/11 Javascript
5个数组Array方法: indexOf、filter、forEach、map、reduce使用实例
2015/01/29 Javascript
JavaScript+CSS实现仿Mootools竖排弹性动画菜单效果
2015/10/14 Javascript
微信小程序 安全包括(框架、功能模块、账户使用)详解
2017/01/16 Javascript
如何使用Bootstrap 按钮实例详解
2017/03/29 Javascript
JS基于贪心算法解决背包问题示例
2017/11/27 Javascript
vue项目中锚点定位替代方式
2019/11/13 Javascript
Vue+Vuex实现自动登录的知识点详解
2020/03/04 Javascript
vue实现简易图片左右旋转,上一张,下一张组件案例
2020/07/31 Javascript
Python 字典与字符串的互转实例
2017/01/13 Python
python如何发布自已pip项目的方法步骤
2018/10/09 Python
Python WEB应用部署的实现方法
2019/01/02 Python
django执行原始查询sql,并返回Dict字典例子
2020/04/01 Python
python库skimage给灰度图像染色的方法示例
2020/04/27 Python
基于Python的接口自动化读写excel文件的方法
2021/01/15 Python
魅力教师事迹材料
2014/01/10 职场文书
商场中秋节广播稿
2014/01/17 职场文书
老公保证书范文
2014/04/29 职场文书
小学阳光体育活动总结
2014/07/05 职场文书
小学师德师风整改措施
2014/10/27 职场文书
2014年妇联工作总结
2014/11/21 职场文书
婚宴父母致辞
2015/07/27 职场文书
辞职离别感言
2015/08/04 职场文书
请学会珍惜眼前,因为人生没有下辈子!
2019/11/12 职场文书
基于nginx实现上游服务器动态自动上下线无需reload的实现方法
2021/03/31 Servers
python中的random模块和相关函数详解
2022/04/22 Python