在tensorflow实现直接读取网络的参数(weight and bias)的值


Posted in Python onJune 24, 2020

训练好了一个网络,想要查看网络里面参数是否经过BP算法优化过,可以直接读取网络里面的参数,如果一直是随机初始化的值,则证明训练代码有问题,需要改。

下面介绍如何直接读取网络的weight 和 bias。

(1) 获取参数的变量名。可以使用一下函数获取变量名:

def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]

输入你想要读取的变量的一部分的名称(scope_name_var),然后通过这个函数返回一个List,里面是所有含有这个名称的变量。

(2) 利用session读取变量的值:

def get_weight(self):
 full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv")
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) ##一定要先初始化变量
  print(sess.run(full_connect_variable[0]))

之后如果想要看参数随着训练的变化,你可以将这些参数保存到一个txt文件里面查看。

补充知识:如何在 PyTorch 中设定学习率衰减(learning rate decay)

在tensorflow实现直接读取网络的参数(weight and bias)的值

很多时候我们要对学习率(learning rate)进行衰减,下面的代码示范了如何每30个epoch按10%的速率衰减:

def adjust_learning_rate(optimizer, epoch):
 """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
 lr = args.lr * (0.1 ** (epoch // 30))
 for param_group in optimizer.param_groups:
  param_group['lr'] = lr

什么是param_groups?

optimizer通过param_group来管理参数组.param_group中保存了参数组及其对应的学习率,动量等等.所以我们可以通过更改param_group[‘lr']的值来更改对应参数组的学习率。

# 有两个`param_group`即,len(optim.param_groups)==2
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
   ], lr=1e-2, momentum=0.9)
 
#一个参数组
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)

以上这篇在tensorflow实现直接读取网络的参数(weight and bias)的值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自动化测试实例解析
Sep 28 Python
python实现简单socket程序在两台电脑之间传输消息的方法
Mar 13 Python
numpy实现合并多维矩阵、list的扩展方法
May 08 Python
判断python字典中key是否存在的两种方法
Aug 10 Python
Python中collections模块的基本使用教程
Dec 07 Python
pandas分区间,算频率的实例
Jul 04 Python
Python 控制终端输出文字的实例
Jul 12 Python
Python实现最常见加密方式详解
Jul 13 Python
python中open函数的基本用法示例
Sep 07 Python
python 实现让字典的value 成为列表
Dec 16 Python
Python编程快速上手——Excel表格创建乘法表案例分析
Feb 28 Python
Python操作Jira库常用方法解析
Apr 10 Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
pytorch中的weight-initilzation用法
Jun 24 #Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
You might like
B2K与车机的中波PK
2021/03/02 无线电
收集的PHP中与数组相关的函数
2007/03/22 PHP
php使用strip_tags()去除html标签仍有空白的解决方法
2016/07/28 PHP
ThinkPHP框架表单验证操作方法
2017/07/19 PHP
tp5框架无刷新分页实现方法分析
2019/09/26 PHP
使用Rancher在K8S上部署高性能PHP应用程序的教程
2020/07/10 PHP
仿服务器端脚本方式的JS模板实现方法
2007/04/27 Javascript
jquery 关键字“拖曳搜索”之“拖曳”以及 图片“提示自适应放大”效果 的实现
2010/04/18 Javascript
js监控IE火狐浏览器关闭、刷新、回退、前进事件
2014/07/23 Javascript
jQuery中siblings()方法用法实例
2015/01/08 Javascript
jquery的幻灯片图片切换效果代码分享
2015/09/07 Javascript
jquery表单验证插件formValidator使用方法
2016/04/01 Javascript
Spring MVC中Ajax实现二级联动的简单实例
2016/07/06 Javascript
jquery属性,遍历,HTML操作方法详解
2016/09/17 Javascript
Vue 实现拖动滑块验证功能(只有css+js没有后台验证步骤)
2018/08/24 Javascript
webpack4 配置 ssr 环境遇到“document is not defined”
2019/10/24 Javascript
Angular单元测试之事件触发的实现
2020/01/20 Javascript
基于Python实现通过微信搜索功能查看谁把你删除了
2016/01/27 Python
python中不能连接超时的问题及解决方法
2018/06/10 Python
python3+PyQt5 实现Rich文本的行编辑方法
2019/06/17 Python
python单向循环链表原理与实现方法示例
2019/12/03 Python
Python的赋值、深拷贝与浅拷贝的区别详解
2020/02/12 Python
django中嵌套的try-except实例
2020/05/21 Python
keras打印loss对权重的导数方式
2020/06/10 Python
Clearly澳大利亚:购买眼镜、太阳镜和隐形眼镜
2018/04/26 全球购物
介绍一下你对SOA的认识
2016/04/24 面试题
小学门卫岗位职责
2013/12/17 职场文书
工程售后服务承诺书
2014/05/21 职场文书
考试诚信承诺书
2014/05/23 职场文书
宿舍标语大全
2014/06/19 职场文书
高速铁道技术专业求职信
2014/08/09 职场文书
2014年初中班主任工作总结
2014/11/08 职场文书
停电调休通知
2015/04/16 职场文书
2015年会计人员工作总结
2015/05/22 职场文书
Nginx本地目录映射实现代码实例
2021/03/31 Servers
nginx配置虚拟主机的详细步骤
2021/07/21 Servers