tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的CURL PycURL使用例子
Jun 01 Python
Tensorflow卷积神经网络实例
May 24 Python
Python使用分布式锁的代码演示示例
Jul 30 Python
Python线性拟合实现函数与用法示例
Dec 13 Python
python中dict使用方法详解
Jul 17 Python
手把手教你pycharm专业版安装破解教程(linux版)
Sep 26 Python
Python threading模块condition原理及运行流程详解
Oct 05 Python
python使用selenium爬虫知乎的方法示例
Oct 28 Python
python爬虫使用scrapy注意事项
Nov 23 Python
python3 通过 pybind11 使用Eigen加速代码的步骤详解
Dec 07 Python
DRF使用simple JWT身份验证的实现
Jan 14 Python
pytho matplotlib工具栏源码探析一之禁用工具栏、默认工具栏和工具栏管理器三种模式的差异
Feb 25 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
十天学会php之第八天
2006/10/09 PHP
php标签云的实现代码
2012/10/10 PHP
thinkphp集成前端脚手架Vue-cli的教程图解
2018/08/30 PHP
thinkPHP框架通过Redis实现增删改查操作的方法详解
2019/05/13 PHP
laravel框架模型和数据库基础操作实例详解
2020/01/25 PHP
TP5框架实现上传多张图片的方法分析
2020/03/29 PHP
jquery设置text的值示例(设置文本框 DIV 表单值)
2014/01/06 Javascript
全面解析Bootstrap中tooltip、popover的使用方法
2016/06/13 Javascript
利用JS实现点击按钮后图片自动切换的简单方法
2016/10/24 Javascript
基于jQuery实现瀑布流页面
2017/04/11 jQuery
Angular 2 利用Router事件和Title实现动态页面标题的方法
2017/08/23 Javascript
webpack实现一个行内样式px转vw的loader示例
2018/09/13 Javascript
新年快乐! javascript实现超级炫酷的3D烟花特效
2019/01/30 Javascript
通过实例讲解JS如何防抖动
2019/06/15 Javascript
Python中pip安装非PyPI官网第三方库的方法
2015/06/02 Python
详解Django中的过滤器
2015/07/16 Python
Python selenium如何设置等待时间
2016/09/15 Python
Python 爬虫学习笔记之单线程爬虫
2016/09/21 Python
Python实现可自定义大小的截屏功能
2018/01/20 Python
python得到电脑的开机时间方法
2018/10/15 Python
Python使用pyserial进行串口通信的实例
2019/07/02 Python
Python socket非阻塞模块应用示例
2019/09/12 Python
python使用opencv实现马赛克效果示例
2019/09/28 Python
VS2019+python3.7+opencv4.1+tensorflow1.13配置详解
2020/04/16 Python
css3+jq创作含苞待放的荷花
2014/02/20 HTML / CSS
福克斯租车:Fox Rent A Car
2017/04/13 全球购物
华为的Java面试题
2014/03/07 面试题
我们的节日清明节活动方案
2014/03/05 职场文书
工作保证书范文
2014/04/29 职场文书
大学社团计划书
2014/05/01 职场文书
员工三分钟演讲稿
2014/08/19 职场文书
党员四风自我剖析材料
2014/10/07 职场文书
小学生成绩单评语
2014/12/31 职场文书
教师工作决心书
2015/02/04 职场文书
安全教育培训心得体会
2016/01/15 职场文书
创业计划书之面包店
2019/09/12 职场文书