tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Mac OS X10.9安装的Python2.7升级Python3.3步骤详解
Dec 04 Python
python使用xlrd实现检索excel中某列含有指定字符串记录的方法
May 09 Python
python先序遍历二叉树问题
Nov 10 Python
pandas object格式转float64格式的方法
Apr 10 Python
python基础知识(一)变量与简单数据类型详解
Apr 17 Python
使用TensorFlow-Slim进行图像分类的实现
Dec 31 Python
Python打开文件、文件读写操作、with方式、文件常用函数实例分析
Jan 07 Python
python如何爬取动态网站
Sep 09 Python
python爬虫基础之urllib的使用
Dec 31 Python
python实现简单的井字棋游戏(gui界面)
Jan 22 Python
关于PySnooper 永远不要使用print进行调试的问题
Mar 04 Python
Windows安装Anaconda3的方法及使用过程详解
Jun 11 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
一个PHP的ZIP压缩类分享
2014/05/04 PHP
Zend Framework入门教程之Zend_Session会话操作详解
2016/12/08 PHP
redis+php实现微博(三)微博列表功能详解
2019/09/23 PHP
PHP迭代器和生成器用法实例分析
2019/09/28 PHP
jQuery封装的获取Url中的Get参数示例
2013/11/26 Javascript
原生js和jquery中有关透明度设置的相关问题
2014/01/08 Javascript
nodejs下打包模块archiver详解
2014/12/03 NodeJs
创建你的第一个AngularJS应用的方法
2015/06/16 Javascript
node.js实现端口转发
2016/04/14 Javascript
在Node.js中使用Javascript Generators详解
2016/05/05 Javascript
JS判断字符串变量是否含有某个字串的实现方法
2016/06/03 Javascript
jQueryUI DatePicker 添加时分秒
2016/06/04 Javascript
AngularJS中transclude用法详解
2016/11/03 Javascript
通过jsonp获取json数据实现AJAX跨域请求
2017/01/22 Javascript
js时间查询插件使用详解
2017/04/07 Javascript
快速对接payjq的个人微信支付接口过程解析
2019/08/15 Javascript
vue中音频wavesurfer.js的使用方法
2020/02/20 Vue.js
vue-resource post数据时碰到Django csrf问题的解决
2020/03/13 Javascript
python通过自定义isnumber函数判断字符串是否为数字的方法
2015/04/23 Python
Python3实现腾讯云OCR识别
2018/11/27 Python
python3实现逐字输出的方法
2019/01/23 Python
Python制作词云图代码实例
2019/09/09 Python
python3.7将代码打包成exe程序并添加图标的方法
2019/10/11 Python
python单向链表的基本实现与使用方法【定义、遍历、添加、删除、查找等】
2019/10/24 Python
django 简单实现登录验证给你
2019/11/06 Python
Python semaphore evevt生产者消费者模型原理解析
2020/03/18 Python
Europcar西班牙:全球汽车租赁领域的领导者
2018/09/17 全球购物
Travelstart沙特阿拉伯:廉价航班、豪华酒店和实惠的汽车租赁优惠
2019/04/06 全球购物
班组长工作职责
2013/12/25 职场文书
工程专业毕业生自荐信范文
2013/12/25 职场文书
教学实验楼管理制度
2014/02/01 职场文书
运动会入场词50字
2014/02/20 职场文书
2014年乡镇卫生院工作总结
2014/11/24 职场文书
公司保洁员岗位职责
2015/02/13 职场文书
2015年幼儿园班主任工作总结
2015/05/12 职场文书
素质拓展训练感想
2015/08/07 职场文书