tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python打印斐波拉契数列实例
Jul 07 Python
Python实现分割文件及合并文件的方法
Jul 10 Python
解决PyCharm import torch包失败的问题
Oct 13 Python
Python脚本按照当前日期创建多级目录
Mar 01 Python
如何使用Python破解ZIP或RAR压缩文件密码
Jan 09 Python
详解Django3中直接添加Websockets方式
Feb 12 Python
详解Python中import机制
Sep 11 Python
python字典通过值反查键的实现(简洁写法)
Sep 30 Python
python wsgiref源码解析
Feb 06 Python
python爬虫爬取某网站视频的示例代码
Feb 20 Python
Python使用Turtle模块绘制国旗的方法示例
Feb 28 Python
python数字图像处理:图像的绘制
Jun 28 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
php排序算法(冒泡排序,快速排序)
2012/10/09 PHP
yii框架表单模型使用及以数组形式提交表单数据示例
2014/04/30 PHP
php生成Android客户端扫描可登录的二维码
2016/05/13 PHP
PHP实现JS中escape与unescape的方法
2016/07/11 PHP
详解PHP归并排序的实现
2016/10/18 PHP
JQuery+DIV自定义滚动条样式的具体实现
2013/06/25 Javascript
jquery validation验证身份证号,护照,电话号码,email(实例代码)
2013/11/06 Javascript
js实现Form栏显示全格式时间时钟效果代码
2015/08/19 Javascript
jquery可定制的在线UEditor编辑器
2015/11/17 Javascript
gulp解决跨域的配置文件问题
2017/06/08 Javascript
jQuery中内容过滤器简单用法示例
2018/03/31 jQuery
Vue路由的模块自动化与统一加载实现
2020/06/05 Javascript
JavaScript实现串行请求的示例代码
2020/09/14 Javascript
uniapp实现可滑动选项卡
2020/10/21 Javascript
[01:31:22]DOTA2-DPC中国联赛定级赛 LBZS vs Magma BO3第二场 1月10日
2021/03/11 DOTA
Python开发WebService系列教程之REST,web.py,eurasia,Django
2014/06/30 Python
python自动格式化json文件的方法
2015/03/11 Python
python开发之for循环操作实例详解
2015/11/12 Python
简单谈谈python中的多进程
2016/11/06 Python
Python使用Pandas库实现MySQL数据库的读写
2019/07/06 Python
Python之time模块的时间戳,时间字符串格式化与转换方法(13位时间戳)
2019/08/12 Python
利用python实现周期财务统计可视化
2019/08/25 Python
用Python解数独的方法示例
2019/10/24 Python
Python imageio读取视频并进行编解码详解
2019/12/10 Python
TensorFlow MNIST手写数据集的实现方法
2020/02/05 Python
django 将自带的数据库sqlite3改成mysql实例
2020/07/09 Python
html5 canvas的绘制文本自动换行的示例代码
2018/09/17 HTML / CSS
Speedo速比涛德国官方网站:世界领先的泳装品牌
2019/08/26 全球购物
Vivo俄罗斯官方在线商店:中国智能手机品牌
2019/10/04 全球购物
领导检查欢迎词
2014/01/14 职场文书
歌唱比赛主持词
2014/03/18 职场文书
软件工程毕业生自荐信
2014/07/04 职场文书
中秋晚会活动方案
2014/08/31 职场文书
污染环境建议书
2015/09/14 职场文书
励志语录:你若不勇敢,谁替你坚强
2019/11/08 职场文书
企业版Windows 11有哪些新功能? Win11适用于企业的功能介绍
2021/11/21 数码科技