tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用python获得时间的实例说明
Mar 25 Python
python之Socket网络编程详解
Sep 29 Python
基于Django filter中用contains和icontains的区别(详解)
Dec 12 Python
python 日志增量抓取实现方法
Apr 28 Python
python中virtualenvwrapper安装与使用
May 20 Python
Window环境下Scrapy开发环境搭建
Nov 18 Python
python opencv 图像拼接的实现方法
Jun 27 Python
python实现最大子序和(分治+动态规划)
Jul 05 Python
用Python获取摄像头并实时控制人脸的实现示例
Jul 11 Python
pandas中ix的使用详细讲解
Mar 09 Python
matlab、python中矩阵的互相导入导出方式
Jun 01 Python
Python检测端口IP字符串是否合法
Jun 05 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
用PHP动态生成虚拟现实VRML网页
2006/10/09 PHP
PHP开发环境配置(MySQL数据库安装图文教程)
2010/04/28 PHP
php图片上传存储源码并且可以预览
2011/08/26 PHP
php实现的Timer页面运行时间监测类
2014/09/24 PHP
JavaScript实用技巧(一)
2010/08/16 Javascript
JS获取文本框,下拉框,单选框的值的简单实例
2014/02/26 Javascript
javascript框架设计之类工厂
2015/06/23 Javascript
基于javascript实现浏览器滚动条快到底部时自动加载数据
2015/11/30 Javascript
jQuery获得字体颜色16位码的方法
2016/02/20 Javascript
微信jssdk在iframe页面失效问题的解决措施
2016/03/03 Javascript
jQuery实现ToolTip元素定位显示功能示例
2016/11/23 Javascript
JavaScript解析JSON格式数据的方法示例
2017/01/24 Javascript
Vue.js事件处理器与表单控件绑定详解
2017/03/20 Javascript
Angular4 中常用的指令入门总结
2017/06/12 Javascript
JS实现移动端按首字母检索城市列表附源码下载
2017/07/05 Javascript
浅谈JavaScript find 方法不支持IE的问题
2017/09/28 Javascript
ionic选择多张图片上传的示例代码
2017/10/10 Javascript
在JS循环中使用async/await的方法
2018/10/12 Javascript
微信小程序点餐系统开发常见问题汇总
2019/08/06 Javascript
Python中的深拷贝和浅拷贝详解
2015/06/03 Python
python中私有函数调用方法解密
2016/04/29 Python
Python文件与文件夹常见基本操作总结
2016/09/19 Python
python微信跳一跳系列之棋子定位颜色识别
2018/02/26 Python
python全栈要学什么 python全栈学习路线
2019/06/28 Python
Python通过VGG16模型实现图像风格转换操作详解
2020/01/16 Python
Python JSON编解码方式原理详解
2020/01/20 Python
使用Python FastAPI构建Web服务的实现
2020/06/08 Python
使用python修改文件并立即写回到原始位置操作(inplace读写)
2020/06/28 Python
完美解决IE8下不兼容rgba()的问题
2017/03/31 HTML / CSS
家居饰品店创业计划书
2014/01/31 职场文书
《彭德怀和他的大黑骡子》教学反思
2014/04/12 职场文书
布达拉宫的导游词
2015/02/02 职场文书
教师文明餐桌光盘行动倡议书
2015/04/28 职场文书
针对吵架老公保证书
2015/05/08 职场文书
公司员工宿舍管理制度
2015/08/03 职场文书
Java并发编程之详解CyclicBarrier线程同步
2021/06/23 Java/Android