tensorflow 获取变量&打印权值的实例讲解


Posted in Python onJune 14, 2018

在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的。

比如用tensorflow的slim库时:

<span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None):</span>
<span style="font-size:14px;"> """Create a resnet style transfer block.</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Args:</span>
<span style="font-size:14px;"> images: [batch-size, height, width, channels] image tensor to feed as input</span>
<span style="font-size:14px;"> output_shape: output image shape in form [height, width, channels]</span>
<span style="font-size:14px;"> hparams: hparams objects</span>
<span style="font-size:14px;"> scope: Variable scope</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> Returns:</span>
<span style="font-size:14px;"> Images after processing with resnet blocks.</span>
<span style="font-size:14px;"> """</span>
<span style="font-size:14px;"> end_points = {}</span>
<span style="font-size:14px;"> if hparams.noise_channel:</span>
<span style="font-size:14px;"> # separate the noise for visualization</span>
<span style="font-size:14px;"> end_points['noise'] = images[:, :, :, -1]</span>
<span style="font-size:14px;"> assert images.shape.as_list()[1:3] == output_shape[0:2]</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;"> with tf.variable_scope(scope, 'resnet_style_transfer', [images]):</span>
<span style="font-size:14px;"> with slim.arg_scope(</span>
<span style="font-size:14px;">  [slim.conv2d],</span>
<span style="font-size:14px;">  normalizer_fn=slim.batch_norm,</span>
<span style="font-size:14px;">  kernel_size=[hparams.generator_kernel_size] * 2,</span>
<span style="font-size:14px;">  stride=1):</span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   images,</span>
<span style="font-size:14px;">   hparams.resnet_filters,</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.relu)</span>
<span style="font-size:14px;">  for block in range(hparams.resnet_blocks):</span>
<span style="font-size:14px;">  net = resnet_block(net, hparams)</span>
<span style="font-size:14px;">  end_points['resnet_block_{}'.format(block)] = net</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">  net = slim.conv2d(</span>
<span style="font-size:14px;">   net,</span>
<span style="font-size:14px;">   output_shape[-1],</span>
<span style="font-size:14px;">   kernel_size=[1, 1],</span>
<span style="font-size:14px;">   normalizer_fn=None,</span>
<span style="font-size:14px;">   activation_fn=tf.nn.tanh,</span>
<span style="font-size:14px;">   scope='conv_out')</span>
<span style="font-size:14px;">  end_points['transferred_images'] = net</span>
<span style="font-size:14px;"> return net, end_points</span>

我们希望获取第一个卷积层的权重weight,该怎么办呢??

在训练时,这些可训练的变量会被tensorflow保存在 tf.trainable_variables() 中,于是我们就可以通过打印 tf.trainable_variables() 来获取该卷积层的名称(或者你也可以自己根据scope来看出来该变量的name ),然后利用tf.get_default_grap().get_tensor_by_name 来获取该变量。

举个简单的例子:

<span style="font-size:14px;">import tensorflow as tf</span>
<span style="font-size:14px;">with tf.variable_scope("generate"):</span>
<span style="font-size:14px;"> with tf.variable_scope("resnet_stack"):</span>
<span style="font-size:14px;">  #简单起见,这里没有用第三方库来说明,</span>
<span style="font-size:14px;">  bias = tf.Variable(0.0,name="bias")</span>
<span style="font-size:14px;">  weight = tf.Variable(0.0,name="weight")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">for tv in tf.trainable_variables():</span>
<span style="font-size:14px;"> print (tv.name)</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")</span>
<span style="font-size:14px;">w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")</span>
<span style="font-size:14px;"></span>
<span style="font-size:14px;">with tf.Session() as sess:</span>
<span style="font-size:14px;"> tf.global_variables_initializer().run()</span>
<span style="font-size:14px;"> print(sess.run(b))</span>
<span style="font-size:14px;"> print(sess.run(w))
</span>

结果如下:

tensorflow 获取变量&amp;打印权值的实例讲解

以上这篇tensorflow 获取变量&打印权值的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python整型运算之布尔型、标准整型、长整型操作示例
Jul 21 Python
Python 绘图库 Matplotlib 入门教程
Apr 19 Python
解决pycharm界面不能显示中文的问题
May 23 Python
tensorflow 获取变量&amp;打印权值的实例讲解
Jun 14 Python
python format 格式化输出方法
Jul 16 Python
使用Python实现画一个中国地图
Nov 23 Python
使用python写一个自动浏览文章的脚本实例
Dec 05 Python
使用TensorFlow-Slim进行图像分类的实现
Dec 31 Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 Python
Python包和模块的分发详细介绍
Jun 19 Python
python多线程和多进程关系详解
Dec 14 Python
python Tkinter的简单入门教程
Apr 11 Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
详解python之协程gevent模块
Jun 14 #Python
python 筛选数据集中列中value长度大于20的数据集方法
Jun 14 #Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
Jun 13 #Python
tensorflow: 查看 tensor详细数值方法
Jun 13 #Python
终端命令查看TensorFlow版本号及路径的方法
Jun 13 #Python
You might like
php smarty 二级分类代码和模版循环例子
2011/06/01 PHP
php 转换字符串编码 iconv与mb_convert_encoding的区别说明
2011/11/10 PHP
php导入导出excel实例
2013/10/25 PHP
php swoft框架实例用法
2020/12/22 PHP
Javascript 中介者模式实例
2009/12/16 Javascript
有道JavaScript监听浏览器的问题
2010/06/23 Javascript
node中socket.io的事件使用详解
2014/12/15 Javascript
JavaScript将字符串转换成字符编码列表的方法
2015/03/19 Javascript
BootStrap智能表单demo示例详解
2016/06/13 Javascript
微信小程序商城项目之购物数量加减(3)
2017/04/17 Javascript
jQuery实现的简单在线计算器功能
2017/05/11 jQuery
Angular获取手机验证码实现移动端登录注册功能
2017/05/17 Javascript
微信小程序实现人脸识别
2018/05/25 Javascript
详解webpack import()动态加载模块踩坑
2018/07/17 Javascript
vue+iview/elementUi实现城市多选
2019/03/28 Javascript
详解一个小实例理解js原型和继承
2019/04/24 Javascript
js实现类似iphone的网页滑屏解锁功能示例【附源码下载】
2019/06/10 Javascript
在node环境下parse Smarty模板的使用示例代码
2019/11/15 Javascript
python实现rest请求api示例
2014/04/22 Python
python实现中文输出的两种方法
2015/05/09 Python
解决django同步数据库的时候app models表没有成功创建的问题
2019/08/09 Python
Python中zip()函数的简单用法举例
2019/09/02 Python
采用专利算法搜索最廉价的机票:CheapAir
2016/09/10 全球购物
Omio中国:全欧洲低价大巴、火车和航班搜索和比价
2018/08/09 全球购物
美国礼品卡交易网站:Cardpool
2018/08/27 全球购物
泰国时尚电商:POMELO Fashion
2020/03/11 全球购物
LUISAVIAROMA中国官网:时尚奢侈品牌购物网站
2020/11/01 全球购物
教师岗位职责
2013/11/17 职场文书
装饰活动策划方案
2014/02/11 职场文书
青年志愿者先进事迹
2014/05/06 职场文书
申论倡议书范文
2014/05/13 职场文书
工厂清洁工岗位职责
2015/02/14 职场文书
缅怀先烈主题班会
2015/08/14 职场文书
商业计划书范文
2019/04/24 职场文书
Nginx+SpringBoot实现负载均衡的示例
2021/03/31 Servers
在项目中使用redis做缓存的一些思路
2021/09/14 Redis