tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作xml文件示例
Apr 07 Python
Python赋值语句后逗号的作用分析
Jun 08 Python
Python的条件表达式和lambda表达式实例
Jan 31 Python
PyQt5实现简易计算器
May 30 Python
Python格式化字符串f-string概览(小结)
Jun 18 Python
python pandas写入excel文件的方法示例
Jun 25 Python
Python with关键字,上下文管理器,@contextmanager文件操作示例
Oct 17 Python
Python多线程thread及模块使用实例
Apr 28 Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 Python
利用python 读写csv文件
Sep 10 Python
Python tkinter制作单机五子棋游戏
Sep 14 Python
python脚本框架webpy的url映射详解
Nov 20 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
PHP分多步骤填写发布信息的简单方法实例代码
2012/09/23 PHP
php隐藏实际地址的文件下载方法
2015/04/18 PHP
简单谈谈PHP中的trait
2017/02/25 PHP
php+jQuery ajax实现的实时刷新显示数据功能示例
2019/09/12 PHP
PHP实现文件上传后台处理脚本
2020/03/04 PHP
初识JQuery 实例一(first)
2011/03/16 Javascript
原生JS可拖动弹窗效果实例代码
2013/11/09 Javascript
jquery实现类似淘宝星星评分功能实例
2014/09/12 Javascript
JavaScript改变CSS样式的方法汇总
2015/05/07 Javascript
基于JQuery实现的跑马灯效果(文字无缝向上翻动)
2016/12/02 Javascript
Jquery实现跨域异步上传文件总结
2017/02/03 Javascript
AngularJS学习第一篇 AngularJS基础知识
2017/02/13 Javascript
Three.js的使用及绘制基础3D图形详解
2017/04/27 Javascript
jQuery实现返回顶部按钮和scroll滚动功能[带动画效果]
2017/07/05 jQuery
bootstrap时间控件daterangepicker使用方法及各种小bug修复
2017/10/25 Javascript
结合mint-ui移动端下拉加载实践方法总结
2017/11/08 Javascript
EasyUI的DataGrid绑定Json数据源的示例代码
2017/12/16 Javascript
详解Node.js中的Async和Await函数
2018/02/22 Javascript
微信小程序自定义底部弹出框
2020/11/16 Javascript
js使用ajax传值给后台,后台返回字符串处理方法
2018/08/08 Javascript
详解JavaScript中的函数、对象
2019/04/01 Javascript
微信小程序开发之左右分栏效果的实例代码
2019/05/20 Javascript
node.js使用 http-proxy 创建代理服务器操作示例
2020/02/10 Javascript
python如何拆分含有多种分隔符的字符串
2018/03/20 Python
django-rest-swagger的优化使用方法
2019/08/29 Python
class类在python中获取金融数据的实例方法
2020/12/10 Python
Pycharm 设置默认解释器路径和编码格式的操作
2021/02/05 Python
HTML5 Canvas绘制文本及图片的基础教程
2016/03/14 HTML / CSS
G-Form护具官方网站:美国运动保护装备
2019/09/04 全球购物
中学生学习生活的自我评价
2013/10/26 职场文书
运动会广播稿300字
2014/01/10 职场文书
主持词开场白
2014/03/17 职场文书
导游词之天津古文化街
2019/11/09 职场文书
MySQL索引知识的一些小妙招总结
2021/05/10 MySQL
如何利用golang运用mysql数据库
2022/03/13 Golang
Vite + React从零开始搭建一个开源组件库
2022/06/25 Javascript