tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用PyGreSQL操作PostgreSQL数据库教程
Jul 30 Python
python TCP Socket的粘包和分包的处理详解
Feb 09 Python
python实现狄克斯特拉算法
Jan 17 Python
Django JWT Token RestfulAPI用户认证详解
Jan 23 Python
Python Pexpect库的简单使用方法
Jan 29 Python
Python3.5基础之变量、数据结构、条件和循环语句、break与continue语句实例详解
Apr 26 Python
python简单验证码识别的实现方法
May 10 Python
pyqt5 获取显示器的分辨率的方法
Jun 18 Python
python3 打印输出字典中特定的某个key的方法示例
Jul 06 Python
基于MATLAB和Python实现MFCC特征参数提取
Aug 13 Python
python rolling regression. 使用 Python 实现滚动回归操作
Jun 08 Python
python turtle绘图
May 04 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
php中cookie实现二级域名可访问操作的方法
2014/11/11 PHP
php计算给定时间之前的函数用法实例
2015/04/03 PHP
在Linux系统的服务器上隐藏PHP版本号的方法
2015/06/06 PHP
Laravel框架使用Seeder实现自动填充数据功能
2018/06/13 PHP
Laravel访问出错提示:`Warning: require(/vendor/autoload.php): failed to open stream: No such file or di解决方法
2019/04/02 PHP
Thinkphp5.0框架视图view的循环标签用法示例
2019/10/12 PHP
利用WebBrowser彻底解决Web打印问题(包括后台打印)
2009/06/22 Javascript
简介JavaScript中的setTime()方法的使用
2015/06/11 Javascript
jQuery鼠标经过方形图片切换成圆边效果代码分享
2015/08/20 Javascript
js检测iframe是否加载完成的方法
2015/11/26 Javascript
Bootstrap每天必学之标签页(Tab)插件
2020/08/09 Javascript
跨域资源共享 CORS 详解
2016/04/26 Javascript
JS 对象(Object)和字符串(String)互转方法
2016/05/20 Javascript
Javascript日期格式化format函数的使用方法
2016/08/30 Javascript
浅谈javascript中的三种弹窗
2016/10/21 Javascript
Vue.js如何优雅的进行form validation
2017/04/07 Javascript
Angular.JS中指令ng-if、ng-show/ng-hide和ng-switch的使用教程
2017/05/07 Javascript
浅谈js基础数据类型和引用类型,深浅拷贝问题,以及内存分配问题
2017/09/02 Javascript
js中document.write和document.writeln的区别
2018/03/11 Javascript
[03:37]2014DOTA2国际邀请赛 主赛事第一日胜者组TOPPLAY
2014/07/19 DOTA
python完成FizzBuzzWhizz问题(拉勾网面试题)示例
2014/05/05 Python
python实现简单的TCP代理服务器
2014/10/08 Python
详解在Python和IPython中使用Docker
2015/04/28 Python
Python判断字符串与大小写转换
2015/06/08 Python
安装PyInstaller失败问题解决
2019/12/14 Python
Pandas之read_csv()读取文件跳过报错行的解决
2020/04/21 Python
html5新增的定时器requestAnimationFrame实现进度条功能
2018/12/13 HTML / CSS
什么是虚拟内存?虚拟内存有什么优势?
2012/02/19 面试题
销售人员自我评价怎么写
2013/09/19 职场文书
文秘专业大学生求职信
2013/11/10 职场文书
大学生优秀团员事迹材料
2014/01/30 职场文书
《鞋匠的儿子》教学反思
2014/03/02 职场文书
法律六进活动方案
2014/03/13 职场文书
安全宣传标语口号
2014/06/06 职场文书
2014年家长学校工作总结
2014/11/20 职场文书
司考复习计划
2015/01/19 职场文书