tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的各种排序算法代码
Mar 04 Python
python 多进程通信模块的简单实现
Feb 20 Python
Python格式化css文件的方法
Mar 10 Python
浅谈python实现Google翻译PDF,解决换行的问题
Nov 28 Python
解决python中的幂函数、指数函数问题
Nov 25 Python
python如何实现不可变字典inmutabledict
Jan 08 Python
tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用
Jan 20 Python
python 日志 logging模块详细解析
Mar 31 Python
在python中求分布函数相关的包实例
Apr 15 Python
Pycharm github配置实现过程图解
Oct 13 Python
python3 使用ssh隧道连接mysql的操作
Dec 05 Python
python 基于DDT实现数据驱动测试
Feb 18 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
php学习之流程控制实现代码
2011/06/09 PHP
ajax取消挂起请求的处理方法
2013/03/18 PHP
PHP常用数组函数介绍
2014/07/28 PHP
PHP+MYSQL中文乱码问题
2015/07/01 PHP
Javascript 圆角div的实现代码
2009/10/15 Javascript
在Windows上安装Node.js模块的方法
2011/09/25 Javascript
使用js声明数组,对象在jsp页面中(获得ajax得到json数据)
2013/11/05 Javascript
javascript学习笔记(七)Ajax和Http状态码
2014/10/08 Javascript
WEB前端开发框架Bootstrap3 VS Foundation5
2016/05/16 Javascript
Node.js中常规的文件操作总结
2016/10/13 Javascript
Vue 固定头 固定列 点击表头可排序的表格组件
2016/11/25 Javascript
详解JS异步加载的三种方式
2017/03/07 Javascript
jquery请求servlet实现ajax异步请求的示例
2017/06/03 jQuery
jQuery Ajax 实现分页 kkpager插件实例代码
2017/08/10 jQuery
微信小程序中setInterval的使用方法
2017/09/29 Javascript
详解vue中axios的使用与封装
2019/03/20 Javascript
详解如何更好的使用module vuex
2019/03/27 Javascript
js设计模式之代理模式及订阅发布模式实例详解
2019/08/15 Javascript
Python实现读取文件最后n行的方法
2017/02/23 Python
python3+PyQt5使用数据库窗口视图
2018/04/24 Python
python GUI库图形界面开发之PyQt5浏览器控件QWebEngineView详细使用方法
2020/02/26 Python
pyinstaller打包单文件时--uac-admin选项不起作用怎么办
2020/04/15 Python
使用pymysql查询数据库,把结果保存为列表并获取指定元素下标实例
2020/05/15 Python
意大利顶级奢侈品电商:LUISAVIAROMA(支持中文)
2020/05/26 全球购物
用JAVA实现一种排序,JAVA类实现序列化的方法(二种)
2014/04/23 面试题
汽车促销活动方案
2014/03/31 职场文书
质量保证书范本
2014/04/29 职场文书
博士生导师推荐信
2014/07/08 职场文书
校园广播稿精选
2014/10/01 职场文书
大学生军训感言
2015/08/01 职场文书
心理健康教育主题班会
2015/08/13 职场文书
24句精辟的现实社会语录,句句扎心,道尽人性
2019/08/29 职场文书
python实现腾讯滑块验证码识别
2021/04/27 Python
golang用type-switch判断interface的实际存储类型
2022/04/14 Golang
讲解MySQL增删改操作
2022/05/06 MySQL
tree shaking对打包体积优化及作用
2022/07/07 Java/Android