在tensorflow实现直接读取网络的参数(weight and bias)的值


Posted in Python onJune 24, 2020

训练好了一个网络,想要查看网络里面参数是否经过BP算法优化过,可以直接读取网络里面的参数,如果一直是随机初始化的值,则证明训练代码有问题,需要改。

下面介绍如何直接读取网络的weight 和 bias。

(1) 获取参数的变量名。可以使用一下函数获取变量名:

def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]

输入你想要读取的变量的一部分的名称(scope_name_var),然后通过这个函数返回一个List,里面是所有含有这个名称的变量。

(2) 利用session读取变量的值:

def get_weight(self):
 full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv")
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) ##一定要先初始化变量
  print(sess.run(full_connect_variable[0]))

之后如果想要看参数随着训练的变化,你可以将这些参数保存到一个txt文件里面查看。

补充知识:如何在 PyTorch 中设定学习率衰减(learning rate decay)

在tensorflow实现直接读取网络的参数(weight and bias)的值

很多时候我们要对学习率(learning rate)进行衰减,下面的代码示范了如何每30个epoch按10%的速率衰减:

def adjust_learning_rate(optimizer, epoch):
 """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
 lr = args.lr * (0.1 ** (epoch // 30))
 for param_group in optimizer.param_groups:
  param_group['lr'] = lr

什么是param_groups?

optimizer通过param_group来管理参数组.param_group中保存了参数组及其对应的学习率,动量等等.所以我们可以通过更改param_group[‘lr']的值来更改对应参数组的学习率。

# 有两个`param_group`即,len(optim.param_groups)==2
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
   ], lr=1e-2, momentum=0.9)
 
#一个参数组
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)

以上这篇在tensorflow实现直接读取网络的参数(weight and bias)的值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的五种异常处理机制介绍
Sep 02 Python
在windows系统中实现python3安装lxml
Mar 23 Python
python中print的不换行即时输出的快速解决方法
Jul 20 Python
利用Python自动监控网站并发送邮件告警的方法
Aug 24 Python
Python实现iOS自动化打包详解步骤
Oct 03 Python
python 3.3 下载固定链接文件并保存的方法
Dec 18 Python
Python语法分析之字符串格式化
Jun 13 Python
对PyQt5中的菜单栏和工具栏实例详解
Jun 20 Python
Python的log日志功能及设置方法
Jul 11 Python
python求一个字符串的所有排列的实现方法
Feb 04 Python
在python里创建一个任务(Task)实例
Apr 25 Python
在python中对于bool布尔值的取反操作
Dec 11 Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
pytorch中的weight-initilzation用法
Jun 24 #Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
You might like
php入门教程 精简版
2009/12/13 PHP
解析PHP中VC6 X86和VC9 X86的区别及 Non Thread Safe的意思
2013/06/28 PHP
php利用反射实现插件机制的方法
2015/03/14 PHP
在WordPress的后台中添加顶级菜单和子菜单的函数详解
2016/01/11 PHP
PHP获取链表中倒数第K个节点的方法
2018/01/18 PHP
php爬取天猫和淘宝商品数据
2018/02/23 PHP
通过jquery实现tab标签浏览效果
2007/02/20 Javascript
javascript显示隐藏层比较不错的方法分析
2008/09/30 Javascript
基于jQuery的固定表格头部的代码(IE6,7,8测试通过)
2010/05/18 Javascript
Javascript实现仿WebQQ界面的“浮云”兼容 IE7以上版本及FF
2011/04/27 Javascript
基于jquery实现的一个选择中国大学的弹框 (数据、步骤、代码)
2012/07/26 Javascript
更快的异步执行(setTimeout多浏览器)
2014/08/12 Javascript
JavaScript实现的背景自动变色代码
2015/10/17 Javascript
详解如何在NodeJS项目中优雅的使用ES6
2017/04/22 NodeJs
微信小程序实现移动端滑动分页效果(ajax)
2017/06/13 Javascript
Django使用多数据库的方法
2017/09/06 Javascript
Vue中props的使用详解
2018/06/15 Javascript
使用Angular 6创建各种动画效果的方法
2018/10/10 Javascript
Vue.js 父子组件通信的十种方式
2018/10/30 Javascript
仿vue-cli搭建属于自己的脚手架的方法步骤
2019/04/17 Javascript
Python PyQt5实现的简易计算器功能示例
2017/08/23 Python
python机器学习之神经网络(二)
2017/12/20 Python
Python学习笔记之自定义函数用法详解
2019/06/08 Python
Python日志syslog使用原理详解
2020/02/18 Python
Python基于pip实现离线打包过程详解
2020/05/15 Python
Jupyter notebook快速入门教程(推荐)
2020/05/18 Python
python报错: 'list' object has no attribute 'shape'的解决
2020/07/15 Python
python logging模块的使用
2020/09/07 Python
Python用摘要算法生成token及检验token的示例代码
2020/12/01 Python
英国100%防污和防水的靴子:Muck Boot Company
2020/09/08 全球购物
Ego Shoes官网:英国时髦鞋类品牌
2020/10/19 全球购物
东京审判观后感
2015/06/01 职场文书
忆童年!用Python实现愤怒的小鸟游戏
2021/06/07 Python
Java实现学生管理系统(IO版)
2022/02/24 Java/Android
Docker部署Mysql8的实现步骤
2022/07/07 Servers
nginx sticky实现基于cookie负载均衡示例详解
2022/12/24 Servers