tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Pycharm学习教程(1) 定制外观
May 02 Python
Python使用openpyxl读写excel文件的方法
Jun 30 Python
python中urlparse模块介绍与使用示例
Nov 19 Python
pandas数值计算与排序方法
Apr 12 Python
python最长回文串算法
Jun 04 Python
python如何生成网页验证码
Jul 28 Python
python生成requirements.txt的两种方法
Sep 18 Python
python实现上传文件到linux指定目录的方法
Jan 03 Python
Python sklearn库实现PCA教程(以鸢尾花分类为例)
Feb 24 Python
Lombok插件安装(IDEA)及配置jar包使用详解
Nov 04 Python
Python中使用Lambda函数的5种用法
Apr 01 Python
Python批量解压&压缩文件夹的示例代码
Apr 04 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
一个用于mysql的数据库抽象层函数库
2006/10/09 PHP
PHP的cURL库功能简介 抓取网页、POST数据及其他
2011/04/07 PHP
PHP函数超时处理方法
2016/02/14 PHP
php获取数据库结果集方法(推荐)
2017/06/01 PHP
限制文本字节数js代码
2007/03/06 Javascript
js 分栏效果实现代码
2009/08/29 Javascript
原生js和jQuery随意改变div属性style的名称和值
2014/10/22 Javascript
深入分析下javascript中的[]()+!
2015/07/07 Javascript
JS去掉字符串中所有的逗号
2017/10/18 Javascript
微信小程序非swiper组件实现的自定义伪3D轮播图效果示例
2018/12/11 Javascript
详解Vue.js中引入图片路径的几种方式
2019/06/17 Javascript
使用flow来规范javascript的变量类型
2019/09/12 Javascript
8个非常实用的Vue自定义指令
2020/12/15 Vue.js
读取本地json文件,解析json(实例讲解)
2017/12/06 Python
PyQt5每天必学之进度条效果
2018/04/19 Python
Pyqt实现无边框窗口拖动以及窗口大小改变
2018/04/19 Python
Django中的ajax请求
2018/10/19 Python
对python当中不在本路径的py文件的引用详解
2018/12/15 Python
python sklearn常用分类算法模型的调用
2019/10/16 Python
用OpenCV进行年龄和性别检测的实现示例
2021/01/29 Python
Python爬取酷狗MP3音频的步骤
2021/02/26 Python
XD健身器材:Kevlar球、Crossfit健身球
2019/03/26 全球购物
Sperry澳大利亚官网:源自美国帆船鞋创始品牌
2019/07/29 全球购物
焊接专业毕业生求职信
2013/10/01 职场文书
施工资料员岗位职责
2014/01/06 职场文书
新春寄语大全
2014/04/09 职场文书
抗洪抢险事迹材料
2014/05/06 职场文书
教师竞聘上岗演讲稿
2014/09/03 职场文书
合作经营协议书范本
2014/09/16 职场文书
走群众路线学习笔记
2014/11/06 职场文书
优秀党员先进事迹材料
2014/12/18 职场文书
党员“一帮一”活动总结
2015/05/07 职场文书
上课讲话检讨书范文
2015/05/07 职场文书
厉行节约工作总结
2015/08/12 职场文书
导游带团欢迎词
2015/09/30 职场文书
Python基本的内置数据类型及使用方法
2022/04/13 Python