tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python SQLite3数据库操作类分享
Jun 10 Python
用Python编写一个简单的Lisp解释器的教程
Apr 03 Python
火车票抢票python代码公开揭秘!
Mar 08 Python
Python中跳台阶、变态跳台阶与矩形覆盖问题的解决方法
May 19 Python
使用Python 自动生成 Word 文档的教程
Feb 13 Python
python 回溯法模板详解
Feb 26 Python
在Mac中PyCharm配置python Anaconda环境过程图解
Mar 11 Python
Django bulk_create()、update()与数据库事务的效率对比分析
May 15 Python
python字典key不能是可以是啥类型
Aug 04 Python
Python并发爬虫常用实现方法解析
Nov 19 Python
python反扒机制的5种解决方法
Feb 06 Python
利用python做表格数据处理
Apr 13 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
PHP4 与 MySQL 数据库操作函数详解
2006/12/06 PHP
PDO::prepare讲解
2019/01/29 PHP
Yii框架组件的事件机制原理与用法分析
2020/04/07 PHP
PHPStorm2020.1永久激活及下载更新至2020(推荐)
2020/09/25 PHP
jQuery 源码分析笔记(3) Deferred机制
2011/06/19 Javascript
Javascript变量作用域详解
2013/12/06 Javascript
鼠标移入移出事件改变图片的分辨率的两种方法
2013/12/17 Javascript
js导航栏单击事件背景变换示例代码
2014/01/13 Javascript
javascript关于继承解析
2016/05/10 Javascript
异步加载JS、CSS代码(推荐)
2016/06/15 Javascript
Node.js Streams文件读写操作详解
2016/07/04 Javascript
微信小程序 教程之注册程序
2016/10/17 Javascript
详解JS几种变量交换方式以及性能分析对比
2016/11/25 Javascript
javascript实现简单的可随机变色网页计算器示例
2016/12/30 Javascript
原生js实现倒计时功能(多种格式调用)
2017/01/12 Javascript
一个简易的js图片轮播效果
2017/07/22 Javascript
浅谈Angular4实现热加载开发旅程
2017/09/08 Javascript
JS将网址url转化为JSON格式的方法
2018/07/02 Javascript
JS判断数组里是否有重复元素的方法小结
2019/05/21 Javascript
React Native中ScrollView组件轮播图与ListView渲染列表组件用法实例分析
2020/01/06 Javascript
微信小程序wx.getUserInfo授权获取用户信息(头像、昵称)的实现
2020/08/19 Javascript
python中将字典转换成其json字符串
2014/07/16 Python
Python之用户输入的实例
2018/06/22 Python
python分块读取大数据,避免内存不足的方法
2018/12/10 Python
在python中获取div的文本内容并和想定结果进行对比详解
2019/01/02 Python
在服务器上安装python3.8.2环境的教程详解
2020/04/26 Python
使用Keras构造简单的CNN网络实例
2020/06/29 Python
Python爬虫Scrapy框架CrawlSpider原理及使用案例
2020/11/20 Python
万户网络JAVA程序员岗位招聘笔试试卷
2013/01/08 面试题
体育教师自荐信范文
2013/12/16 职场文书
会计与出纳自荐书范文
2014/03/16 职场文书
小学生手册家长评语
2014/04/16 职场文书
领导干部学习“三严三实”思想汇报
2014/09/15 职场文书
学校安全管理制度
2015/08/06 职场文书
实习报告范文之电话客服岗位
2019/07/26 职场文书
CSS3 制作的彩虹按钮样式
2021/04/11 HTML / CSS