tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用稀疏矩阵节省内存实例
Jun 27 Python
Web服务器框架 Tornado简介
Jul 16 Python
python基于socket实现网络广播的方法
Apr 29 Python
python简单实现刷新智联简历
Mar 30 Python
浅谈python中的正则表达式(re模块)
Oct 17 Python
分析Python中解析构建数据知识
Jan 20 Python
djano一对一、多对多、分页实例代码
Aug 16 Python
python 函数的缺省参数使用注意事项分析
Sep 17 Python
使用批处理脚本自动生成并上传NuGet包(操作方法)
Nov 19 Python
利用PyCharm操作Github(仓库新建、更新,代码回滚)
Dec 18 Python
python自动化测试三部曲之unittest框架的实现
Oct 07 Python
python实现MySQL指定表增量同步数据到clickhouse的脚本
Feb 26 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
用PHP实现的随机广告显示代码
2007/06/14 PHP
PHP读取文件并可支持远程文件的代码分享
2012/10/03 PHP
深入php socket的讲解与实例分析
2013/06/13 PHP
laravel excel 上传文件保存到本地服务器功能
2019/11/14 PHP
Node.js异步I/O学习笔记
2014/11/04 Javascript
Node.js中调用mysql存储过程示例
2014/12/20 Javascript
JS实现点击按钮后框架内载入不同网页的方法
2015/05/05 Javascript
bootstrap下拉列表与输入框组结合的样式调整
2016/10/08 Javascript
JS中用childNodes获取子元素换行会产生一个子元素
2016/12/08 Javascript
基于canvas的二维码邀请函生成插件
2017/02/14 Javascript
js实现城市级联菜单的2种方法
2017/06/23 Javascript
js实现鼠标移动到图片产生遮罩效果
2017/10/21 Javascript
JS实现的计数排序与基数排序算法示例
2017/12/04 Javascript
vue.js提交按钮时进行简单的if判断表达式详解
2018/08/08 Javascript
NodeJs 文件系统操作模块fs使用方法详解
2018/11/26 NodeJs
vue路由教程之静态路由
2019/09/03 Javascript
vue本地打开build后生成的dist文件夹index.html问题
2019/09/04 Javascript
JavaScript代理模式原理与用法实例详解
2020/03/10 Javascript
Vue CLI3移动端适配(px2rem或postcss-plugin-px2rem)
2020/04/27 Javascript
深入理解Python中的*重复运算符
2017/10/28 Python
PyQt5实现下载进度条效果
2018/04/19 Python
python实现人人自动回复、抢沙发功能
2018/06/08 Python
Python3导入CSV文件的实例(跟Python2有些许的不同)
2018/06/22 Python
Python拆分大型CSV文件代码实例
2019/10/07 Python
python对批量WAV音频进行等长分割的方法实现
2020/09/25 Python
基于html5 canvas做批改作业的小插件
2020/05/20 HTML / CSS
什么是GWT的Module
2013/01/20 面试题
受欢迎的大学生自我评价
2013/12/05 职场文书
村捐赠仪式答谢词
2014/01/21 职场文书
家居饰品店创业计划书
2014/01/31 职场文书
个人四风问题对照检查材料
2014/10/01 职场文书
领导班子个人查摆问题对照检查材料
2014/10/02 职场文书
股东授权委托书
2014/10/15 职场文书
电力安全学习心得体会
2016/01/18 职场文书
2016年助残日旅游活动总结
2016/04/01 职场文书
导游词之任弼时故居
2020/01/07 职场文书