在tensorflow实现直接读取网络的参数(weight and bias)的值


Posted in Python onJune 24, 2020

训练好了一个网络,想要查看网络里面参数是否经过BP算法优化过,可以直接读取网络里面的参数,如果一直是随机初始化的值,则证明训练代码有问题,需要改。

下面介绍如何直接读取网络的weight 和 bias。

(1) 获取参数的变量名。可以使用一下函数获取变量名:

def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]

输入你想要读取的变量的一部分的名称(scope_name_var),然后通过这个函数返回一个List,里面是所有含有这个名称的变量。

(2) 利用session读取变量的值:

def get_weight(self):
 full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv")
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) ##一定要先初始化变量
  print(sess.run(full_connect_variable[0]))

之后如果想要看参数随着训练的变化,你可以将这些参数保存到一个txt文件里面查看。

补充知识:如何在 PyTorch 中设定学习率衰减(learning rate decay)

在tensorflow实现直接读取网络的参数(weight and bias)的值

很多时候我们要对学习率(learning rate)进行衰减,下面的代码示范了如何每30个epoch按10%的速率衰减:

def adjust_learning_rate(optimizer, epoch):
 """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
 lr = args.lr * (0.1 ** (epoch // 30))
 for param_group in optimizer.param_groups:
  param_group['lr'] = lr

什么是param_groups?

optimizer通过param_group来管理参数组.param_group中保存了参数组及其对应的学习率,动量等等.所以我们可以通过更改param_group[‘lr']的值来更改对应参数组的学习率。

# 有两个`param_group`即,len(optim.param_groups)==2
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
   ], lr=1e-2, momentum=0.9)
 
#一个参数组
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)

以上这篇在tensorflow实现直接读取网络的参数(weight and bias)的值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python学习资料
Feb 08 Python
python实现端口转发器的方法
Mar 13 Python
Python 变量的创建过程详解
Sep 02 Python
Django项目创建到启动详解(最全最详细)
Sep 07 Python
python获取网络图片方法及整理过程详解
Dec 20 Python
Python调用.NET库的方法步骤
Dec 27 Python
pytorch对梯度进行可视化进行梯度检查教程
Feb 04 Python
Python写捕鱼达人的游戏实现
Mar 31 Python
python实现PDF中表格转化为Excel的方法
Jun 16 Python
如何通过python计算圆周率PI
Nov 11 Python
python读取图片颜色值并生成excel像素画的方法实例
Feb 19 Python
python解决OpenCV在读取显示图片的时候闪退的问题
Feb 23 Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
pytorch中的weight-initilzation用法
Jun 24 #Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
You might like
解析func_num_args与func_get_args函数的使用
2013/06/24 PHP
PHP设计模式之装饰器模式定义与用法详解
2018/04/02 PHP
PHP registerXPathNamespace()函数讲解
2019/02/03 PHP
PHP XML Expat解析器知识点总结
2019/02/15 PHP
使用PHP开发留言板功能
2019/11/19 PHP
解决表单中第一个非隐藏的元素获得焦点的一个方案
2009/10/26 Javascript
button没写type=button会导致点击时提交
2014/03/06 Javascript
js获取当前日期时间及其它操作汇总
2015/04/17 Javascript
js实现商城星星评分的效果
2015/12/29 Javascript
AngularJS应用开发思维之依赖注入3
2016/08/19 Javascript
JS关闭窗口时产生的事件及用法示例
2016/08/20 Javascript
Angular.JS学习之依赖注入$injector详析
2016/10/20 Javascript
jQuery导航条固定定位效果实例代码
2017/05/26 jQuery
详解angular脏检查原理及伪代码实现
2018/06/08 Javascript
Vue项目history模式下微信分享爬坑总结
2019/03/29 Javascript
vue-form表单验证是否为空值的实例详解
2019/10/29 Javascript
javascript实现下拉菜单效果
2021/02/09 Javascript
Python 爬虫的工具列表大全
2016/01/31 Python
对pandas中两种数据类型Series和DataFrame的区别详解
2018/11/12 Python
python写程序统计词频的方法
2019/07/29 Python
Python目录和文件处理总结详解
2019/09/02 Python
Python 将 QQ 好友头像生成祝福语的实现代码
2020/05/03 Python
Python urllib2运行过程原理解析
2020/06/04 Python
Keras 使用 Lambda层详解
2020/06/10 Python
使用keras内置的模型进行图片预测实例
2020/06/17 Python
OpenCV实现机器人对物体进行移动跟随的方法实例
2020/11/09 Python
用CSS禁用输入法(CSS3 UI规范)实例解析
2012/12/04 HTML / CSS
基于HTML5的WebGL经典3D虚拟机房漫游动画
2017/11/15 HTML / CSS
La Redoute英国官网:法国时尚品牌
2017/04/27 全球购物
日本著名的服饰鞋帽综合类购物网站:MAGASEEK
2019/01/09 全球购物
Nanushka官网:匈牙利服装品牌
2019/08/14 全球购物
Solaris操作系统的线程机制
2012/12/23 面试题
安全教育实施方案
2014/03/02 职场文书
护林员个人总结
2015/03/04 职场文书
总经理年会致辞
2015/07/29 职场文书
事业单位岗位说明书
2015/10/08 职场文书