在tensorflow实现直接读取网络的参数(weight and bias)的值


Posted in Python onJune 24, 2020

训练好了一个网络,想要查看网络里面参数是否经过BP算法优化过,可以直接读取网络里面的参数,如果一直是随机初始化的值,则证明训练代码有问题,需要改。

下面介绍如何直接读取网络的weight 和 bias。

(1) 获取参数的变量名。可以使用一下函数获取变量名:

def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]

输入你想要读取的变量的一部分的名称(scope_name_var),然后通过这个函数返回一个List,里面是所有含有这个名称的变量。

(2) 利用session读取变量的值:

def get_weight(self):
 full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv")
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) ##一定要先初始化变量
  print(sess.run(full_connect_variable[0]))

之后如果想要看参数随着训练的变化,你可以将这些参数保存到一个txt文件里面查看。

补充知识:如何在 PyTorch 中设定学习率衰减(learning rate decay)

在tensorflow实现直接读取网络的参数(weight and bias)的值

很多时候我们要对学习率(learning rate)进行衰减,下面的代码示范了如何每30个epoch按10%的速率衰减:

def adjust_learning_rate(optimizer, epoch):
 """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
 lr = args.lr * (0.1 ** (epoch // 30))
 for param_group in optimizer.param_groups:
  param_group['lr'] = lr

什么是param_groups?

optimizer通过param_group来管理参数组.param_group中保存了参数组及其对应的学习率,动量等等.所以我们可以通过更改param_group[‘lr']的值来更改对应参数组的学习率。

# 有两个`param_group`即,len(optim.param_groups)==2
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
   ], lr=1e-2, momentum=0.9)
 
#一个参数组
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)

以上这篇在tensorflow实现直接读取网络的参数(weight and bias)的值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
收藏整理的一些Python常用方法和技巧
May 18 Python
探究python中open函数的使用
Mar 01 Python
python脚本监控docker容器
Apr 27 Python
用virtualenv建立多个Python独立虚拟开发环境
Jul 06 Python
python 中if else 语句的作用及示例代码
Mar 05 Python
numpy中的delete删除数组整行和整列的实例
May 09 Python
利用Python如何批量修改数据库执行Sql文件
Jul 29 Python
简单了解python变量的作用域
Jul 30 Python
python实现连续变量最优分箱详解--CART算法
Nov 22 Python
python基于plotly实现画饼状图代码实例
Dec 16 Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 Python
Python figure参数及subplot子图绘制代码
Apr 18 Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
pytorch中的weight-initilzation用法
Jun 24 #Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
You might like
加速XP搜索功能堪比vista
2007/03/22 PHP
PHP的变量总结 新手推荐
2011/04/18 PHP
PHP简洁函数(PHP简单明了函数语法)
2012/06/10 PHP
jQuery获取json后使用zy_tmpl生成下拉菜单
2015/03/27 PHP
PHP中如何防止外部恶意提交调用ajax接口
2016/04/11 PHP
php微信公众平台开发之微信群发信息
2016/09/13 PHP
PHP扩展mcrypt实现的AES加密功能示例
2019/01/29 PHP
字符串的replace方法应用浅析
2011/12/06 Javascript
在服务端(Page.Write)调用自定义的JS方法详解
2013/08/09 Javascript
实例说明为什么不要行内使用javascript
2014/04/18 Javascript
node.js 中国天气预报 简单实现
2016/06/06 Javascript
backbone简介_动力节点Java学院整理
2017/07/14 Javascript
jQuery实现基本隐藏与显示效果的方法详解
2018/09/05 jQuery
nodejs 使用http进行post或get请求的实例(携带cookie)
2019/01/03 NodeJs
vue切换菜单取消未完成接口请求的案例
2020/11/13 Javascript
[02:07]DOTA2超级联赛专访BBC:难忘网吧超神经历
2013/06/09 DOTA
[01:41]DOTA2超级联赛专访YYF 称一辈子难忘TI2
2013/05/28 DOTA
[04:22]DOTA2大事件之护国神翼
2020/08/14 DOTA
Python实现的多线程http压力测试代码
2017/02/08 Python
Python探索之Metaclass初步了解
2017/10/28 Python
Python网络编程使用select实现socket全双工异步通信功能示例
2018/04/09 Python
Python从使用线程到使用async/await的深入讲解
2018/09/16 Python
Scrapy框架爬取西刺代理网免费高匿代理的实现代码
2019/02/22 Python
解决django同步数据库的时候app models表没有成功创建的问题
2019/08/09 Python
flask框架json数据的拿取和返回操作示例
2019/11/28 Python
详解python 破解网站反爬虫的两种简单方法
2020/02/09 Python
使用sklearn对多分类的每个类别进行指标评价操作
2020/06/11 Python
澳大利亚家具和家居用品在线:BROSA
2017/11/02 全球购物
玩具公司的创业计划书
2013/12/31 职场文书
公司部门司机岗位职责
2014/01/03 职场文书
军校大学生个人的自我评价
2014/02/17 职场文书
总经理工作职责范文
2014/03/14 职场文书
党员应该树立反腐倡廉的坚定意识思想汇报
2014/09/12 职场文书
个人党性分析材料
2014/12/19 职场文书
pandas中DataFrame检测重复值的实现
2021/05/26 Python
python cv2图像质量压缩的算法示例
2021/06/04 Python