tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 随机数生成的代码的详细分析
May 15 Python
使用Python的urllib和urllib2模块制作爬虫的实例教程
Jan 20 Python
Django模板变量如何传递给外部js调用的方法小结
Jul 24 Python
Python抓取框架Scrapy爬虫入门:页面提取
Dec 01 Python
Python解决线性代数问题之矩阵的初等变换方法
Dec 12 Python
解决Pycharm界面的子窗口不见了的问题
Jan 17 Python
python 实现selenium断言和验证的方法
Feb 13 Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 Python
Django密码存储策略分析
Jan 09 Python
如何通过Python3和ssl实现加密通信功能
May 09 Python
Python用access判断文件是否被占用的实例方法
Dec 17 Python
Django实现WebSocket在线聊天室功能(channels库)
Sep 25 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
打造计数器DIY三步曲(上)
2006/10/09 PHP
discuz7 phpMysql操作类
2009/06/21 PHP
深入分析使用mysql_fetch_object()以对象的形式返回查询结果
2013/06/05 PHP
hadoop常见错误以及处理方法详解
2013/06/19 PHP
php输入流php://input使用示例(php发送图片流到服务器)
2013/12/25 PHP
Yii2简单实现多语言配置的方法
2016/07/23 PHP
Linux平台PHP5.4设置FPM线程数量的方法
2016/11/09 PHP
PHP的imageTtfText()函数深入详解
2021/03/03 PHP
收藏Javascript中常用的55个经典技巧
2007/08/12 Javascript
基于JQuery的cookie插件
2010/04/07 Javascript
IE6下js通过css隐藏select的一个bug
2010/08/16 Javascript
提高javascript效率 一次判断,而不要次次判断
2012/03/30 Javascript
jquery插件之定时查询待处理任务数量
2014/05/01 Javascript
jquery删除数据记录时的弹出提示效果
2014/05/06 Javascript
JS中获取函数调用链所有参数的方法
2015/05/07 Javascript
详解AngularJS 模态对话框
2016/04/07 Javascript
原生js的RSA和AES加密解密算法
2016/10/08 Javascript
JavaScript实现简单轮播图效果
2018/12/01 Javascript
如何在vue里面优雅的解决跨域(路由冲突问题)
2019/01/20 Javascript
关于Vue源码vm.$watch()内部原理详解
2019/04/26 Javascript
vue实现在进行增删改操作后刷新页面
2020/08/05 Javascript
python开发中module模块用法实例分析
2015/11/12 Python
Python编写登陆接口的方法
2017/07/10 Python
20个常用Python运维库和模块
2018/02/12 Python
Windows下Anaconda2安装NLTK教程
2018/09/19 Python
Python selenium爬虫实现定时任务过程解析
2020/06/08 Python
Python字符串三种格式化输出
2020/09/17 Python
Python页面加载的等待方式总结
2021/02/28 Python
澳大利亚在线购买葡萄酒:The Wine Collective
2020/02/20 全球购物
教学器材管理制度
2014/01/26 职场文书
秋游活动策划方案
2014/02/16 职场文书
赡养老人协议书
2014/04/21 职场文书
2015年班干部工作总结
2015/04/29 职场文书
2019军训心得体会
2019/06/27 职场文书
PostgreSQL事务回卷实战案例详析
2022/03/25 PostgreSQL
Win10玩csgo闪退如何解决?Win10玩csgo闪退的解决方法
2022/07/23 数码科技