在tensorflow实现直接读取网络的参数(weight and bias)的值


Posted in Python onJune 24, 2020

训练好了一个网络,想要查看网络里面参数是否经过BP算法优化过,可以直接读取网络里面的参数,如果一直是随机初始化的值,则证明训练代码有问题,需要改。

下面介绍如何直接读取网络的weight 和 bias。

(1) 获取参数的变量名。可以使用一下函数获取变量名:

def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]

输入你想要读取的变量的一部分的名称(scope_name_var),然后通过这个函数返回一个List,里面是所有含有这个名称的变量。

(2) 利用session读取变量的值:

def get_weight(self):
 full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv")
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) ##一定要先初始化变量
  print(sess.run(full_connect_variable[0]))

之后如果想要看参数随着训练的变化,你可以将这些参数保存到一个txt文件里面查看。

补充知识:如何在 PyTorch 中设定学习率衰减(learning rate decay)

在tensorflow实现直接读取网络的参数(weight and bias)的值

很多时候我们要对学习率(learning rate)进行衰减,下面的代码示范了如何每30个epoch按10%的速率衰减:

def adjust_learning_rate(optimizer, epoch):
 """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
 lr = args.lr * (0.1 ** (epoch // 30))
 for param_group in optimizer.param_groups:
  param_group['lr'] = lr

什么是param_groups?

optimizer通过param_group来管理参数组.param_group中保存了参数组及其对应的学习率,动量等等.所以我们可以通过更改param_group[‘lr']的值来更改对应参数组的学习率。

# 有两个`param_group`即,len(optim.param_groups)==2
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
   ], lr=1e-2, momentum=0.9)
 
#一个参数组
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)

以上这篇在tensorflow实现直接读取网络的参数(weight and bias)的值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python抓取京东商城手机列表url实例代码
Dec 18 Python
列举Python中吸引人的一些特性
Apr 09 Python
python 读入多行数据的实例
Apr 19 Python
Python使用jsonpath-rw模块处理Json对象操作示例
Jul 31 Python
Python 比较文本相似性的方法(difflib,Levenshtein)
Oct 15 Python
Python 3.8新特征之asyncio REPL
May 28 Python
Python pip替换为阿里源的方法步骤
Jul 02 Python
Python用字典构建多级菜单功能
Jul 11 Python
Python树莓派学习笔记之UDP传输视频帧操作详解
Nov 15 Python
python环境搭建和pycharm的安装配置及汉化详细教程(零基础小白版)
Aug 19 Python
2020年10款优秀的Python第三方库,看看有你中意的吗?
Jan 12 Python
opencv-python图像配准(匹配和叠加)的实现
Jun 23 Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
pytorch中的weight-initilzation用法
Jun 24 #Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
You might like
php 什么是PEAR?
2009/03/19 PHP
PHP Session机制简介及用法
2014/08/19 PHP
smarty简单入门实例
2014/11/28 PHP
ThinkPHP自动完成中使用函数与回调方法实例
2014/11/29 PHP
PHP内核探索:哈希表碰撞攻击原理
2015/07/31 PHP
PHP添加文字水印或图片水印的水印类完整源代码与使用示例
2019/03/18 PHP
window.parent调用父框架时 ie跟火狐不兼容问题
2009/07/30 Javascript
再论Javascript的类继承
2011/03/05 Javascript
jquery单行文字向上滚动效果示例
2014/03/06 Javascript
jQuery-1.9.1源码分析系列(十)事件系统之事件体系结构
2015/11/19 Javascript
jQuery实现摸拟alert提示框
2016/05/22 Javascript
angularjs使用directive实现分页组件的示例
2017/02/07 Javascript
node.js中grunt和gulp的区别详解
2017/07/17 Javascript
JS实现的加减乘除四则运算计算器示例
2017/08/09 Javascript
教你如何编写Vue.js的单元测试的方法
2018/10/17 Javascript
原生js实现each方法实例代码详解
2019/05/27 Javascript
Vue编写可显示周和月模式的日历 Vue自定义日历内容的显示
2019/06/26 Javascript
Javascript类型判断相关例题及解析
2020/08/26 Javascript
[01:29]Ti4循环赛第三日精彩回顾
2014/07/13 DOTA
Python常见异常分类与处理方法
2017/06/04 Python
NetworkX之Prim算法(实例讲解)
2017/12/22 Python
Python实现备份MySQL数据库的方法示例
2018/01/11 Python
利用python 更新ssh 远程代码 操作远程服务器的实现代码
2018/02/08 Python
python环境路径配置以及命令行运行脚本
2019/04/02 Python
Python批量查询关键词微信指数实例方法
2019/06/27 Python
python multiprocessing多进程变量共享与加锁的实现
2019/10/02 Python
Python网络爬虫四大选择器用法原理总结
2020/06/01 Python
python实现图像外边界跟踪操作
2020/07/13 Python
Python timeit模块原理及使用方法
2020/10/10 Python
css3动画过渡实现鼠标跟随导航效果
2018/02/08 HTML / CSS
用CSS3写的模仿iPhone中的返回按钮
2015/04/04 HTML / CSS
canvas版人体时钟的实现示例
2021/01/29 HTML / CSS
优秀护士获奖感言
2014/02/20 职场文书
2014年学生资助工作总结
2014/12/18 职场文书
莫言获奖感言(全文)
2015/07/31 职场文书
2016入党积极分子党校培训心得体会
2016/01/06 职场文书