在tensorflow实现直接读取网络的参数(weight and bias)的值


Posted in Python onJune 24, 2020

训练好了一个网络,想要查看网络里面参数是否经过BP算法优化过,可以直接读取网络里面的参数,如果一直是随机初始化的值,则证明训练代码有问题,需要改。

下面介绍如何直接读取网络的weight 和 bias。

(1) 获取参数的变量名。可以使用一下函数获取变量名:

def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]

输入你想要读取的变量的一部分的名称(scope_name_var),然后通过这个函数返回一个List,里面是所有含有这个名称的变量。

(2) 利用session读取变量的值:

def get_weight(self):
 full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv")
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) ##一定要先初始化变量
  print(sess.run(full_connect_variable[0]))

之后如果想要看参数随着训练的变化,你可以将这些参数保存到一个txt文件里面查看。

补充知识:如何在 PyTorch 中设定学习率衰减(learning rate decay)

在tensorflow实现直接读取网络的参数(weight and bias)的值

很多时候我们要对学习率(learning rate)进行衰减,下面的代码示范了如何每30个epoch按10%的速率衰减:

def adjust_learning_rate(optimizer, epoch):
 """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
 lr = args.lr * (0.1 ** (epoch // 30))
 for param_group in optimizer.param_groups:
  param_group['lr'] = lr

什么是param_groups?

optimizer通过param_group来管理参数组.param_group中保存了参数组及其对应的学习率,动量等等.所以我们可以通过更改param_group[‘lr']的值来更改对应参数组的学习率。

# 有两个`param_group`即,len(optim.param_groups)==2
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
   ], lr=1e-2, momentum=0.9)
 
#一个参数组
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)

以上这篇在tensorflow实现直接读取网络的参数(weight and bias)的值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
讲解python参数和作用域的使用
Nov 01 Python
python检查序列seq是否含有aset中项的方法
Jun 30 Python
Python获取某一天是星期几的方法示例
Jan 17 Python
利用Python爬取微博数据生成词云图片实例代码
Aug 31 Python
浅析Python数据处理
May 02 Python
Python 利用scrapy爬虫通过短短50行代码下载整站短视频
Oct 29 Python
对python同一个文件夹里面不同.py文件的交叉引用方法详解
Dec 15 Python
python实现在遍历列表时,直接对dict元素增加字段的方法
Jan 15 Python
利用python实现凯撒密码加解密功能
Mar 31 Python
python调用API接口实现登陆短信验证
May 10 Python
python小白切忌乱用表达式
May 29 Python
快速创建python 虚拟环境
Nov 28 Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
pytorch中的weight-initilzation用法
Jun 24 #Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
You might like
php环境配置之CGI、FastCGI、PHP-CGI、PHP-FPM、Spawn-FCGI比较?
2011/10/17 PHP
php通过strpos查找字符串出现位置的方法
2015/03/17 PHP
php操作MongoDB类实例
2015/06/17 PHP
用JQuery在网页中实现分隔条功能的代码
2012/08/09 Javascript
javascript获取ckeditor编辑器的值(实现代码)
2013/11/18 Javascript
document.forms[].submit()使用介绍
2014/02/19 Javascript
jQuery垂直多级导航菜单代码分享
2015/08/18 Javascript
jquery表单验证插件formValidator使用方法
2016/04/01 Javascript
JavaScript通过filereader接口读取文件
2017/05/10 Javascript
Vue单文件组件的如何使用方式介绍
2017/07/28 Javascript
vue+element实现批量删除功能的示例
2018/02/28 Javascript
vsCode安装使用教程和插件安装方法
2020/08/24 Javascript
JS判断两个数组或对象是否相同的方法示例
2019/02/28 Javascript
如何解决日期函数new Date()浏览器兼容性问题
2019/09/11 Javascript
js仿360开机效果
2019/12/26 Javascript
精读《Vue3.0 Function API》
2020/05/20 Javascript
Vue CLI4 Vue.config.js标准配置(最全注释)
2020/06/05 Javascript
实践Python的爬虫框架Scrapy来抓取豆瓣电影TOP250
2016/01/20 Python
使用Python对SQLite数据库操作
2017/04/06 Python
Python中标准库OS的常用方法总结大全
2017/07/19 Python
Python 中的Selenium异常处理实例代码
2018/05/03 Python
python+opencv实现阈值分割
2018/12/26 Python
Python列表元素常见操作简单示例
2019/10/25 Python
flask 框架操作MySQL数据库简单示例
2020/02/02 Python
使用python求斐波那契数列中第n个数的值示例代码
2020/07/26 Python
HTML5的新特性(1)
2016/03/03 HTML / CSS
澳大利亚鞋仓库:Shoe Warehouse
2019/07/25 全球购物
美国和加拿大计算机和电子产品购物网站:TigerDirect.com
2019/09/13 全球购物
ASICS印度官方网站:日本专业运动品牌
2020/06/20 全球购物
自荐信的禁忌和要点
2013/10/15 职场文书
年会主持词结束语
2014/03/27 职场文书
2014年医院十一国庆节活动方案
2014/09/15 职场文书
教师作风整顿个人剖析材料
2014/10/10 职场文书
人事任命通知书
2015/04/21 职场文书
科普 | 业余无线电知识-波段篇
2022/02/18 无线电
Nginx+Windows搭建域名访问环境的操作方法
2022/03/17 Servers