tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python统计文本文件内单词数量的方法
May 30 Python
python结合API实现即时天气信息
Jan 19 Python
Python基于time模块求程序运行时间的方法
Sep 18 Python
Python调用C# Com dll组件实战教程
Oct 12 Python
详解K-means算法在Python中的实现
Dec 05 Python
python使用Flask操作mysql实现登录功能
May 14 Python
利用Python进行数据可视化常见的9种方法!超实用!
Jul 11 Python
python实现感知机线性分类模型示例代码
Jun 02 Python
Python使用scipy模块实现一维卷积运算示例
Sep 05 Python
面向对象学习之pygame坦克大战
Sep 11 Python
python输出数组中指定元素的所有索引示例
Dec 06 Python
Python Pandas数据分析之iloc和loc的用法详解
Nov 11 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
php+jquery编码方面的一些心得(utf-8 gb2312)
2010/10/12 PHP
php 字符串替换的方法
2012/01/10 PHP
在win7中搭建Linux+PHP 开发环境
2014/10/08 PHP
ThinkPHP在新浪SAE平台的部署实例
2014/10/31 PHP
Laravel框架生命周期与原理分析
2018/06/12 PHP
Ajax+PHP实现的模拟进度条功能示例
2019/02/11 PHP
thinkPHP5.1框架中Request类四种调用方式示例
2019/08/03 PHP
javascript 日期时间函数(经典+完善+实用)
2009/05/27 Javascript
Jquery 1.42 checkbox 全选和反选代码
2010/03/27 Javascript
用js小类库获取浏览器的高度和宽度信息
2012/01/15 Javascript
javascript 实现子父窗体互相传值的简单实例
2014/02/17 Javascript
php的文件上传入门教程(实例讲解)
2014/04/10 Javascript
原生Javascript封装的一个AJAX函数分享
2014/10/11 Javascript
使用Raygun对Node.js应用进行错误处理的方法
2015/06/23 Javascript
js判断移动端是否安装某款app的多种方法
2015/12/18 Javascript
js实现图片缓慢放大缩小效果
2016/08/02 Javascript
vue+Java后端进行调试时解决跨域问题的方式
2017/10/19 Javascript
Vue.js实现可排序的表格组件功能示例
2019/02/19 Javascript
微信小程序上传图片到php服务器的方法
2019/05/23 Javascript
vue2路由基本用法实例分析
2020/03/06 Javascript
vue中移动端调取本地的复制的文本方式
2020/07/18 Javascript
python urllib爬取百度云连接的实例代码
2017/06/19 Python
基于Python对象引用、可变性和垃圾回收详解
2017/08/21 Python
浅谈Django REST Framework限速
2017/12/12 Python
详解Python安装scrapy的正确姿势
2018/06/26 Python
Python 的字典(Dict)是如何存储的
2019/07/05 Python
解决使用Pandas 读取超过65536行的Excel文件问题
2020/11/10 Python
Canvas 文字碰撞检测并抽稀的方法
2019/05/27 HTML / CSS
美国第二大连锁药店:Rite Aid
2019/04/03 全球购物
茱莉蔻美国官网:Jurlique美国
2020/11/24 全球购物
计算机专业推荐信范文
2013/11/27 职场文书
先进工作者获奖感言
2014/02/08 职场文书
古汉语文学求职信范文
2014/03/16 职场文书
2014年采购工作总结
2014/11/20 职场文书
2015年会计年终工作总结
2015/05/26 职场文书
pytorch实现加载保存查看checkpoint文件
2022/07/15 Python