在tensorflow实现直接读取网络的参数(weight and bias)的值


Posted in Python onJune 24, 2020

训练好了一个网络,想要查看网络里面参数是否经过BP算法优化过,可以直接读取网络里面的参数,如果一直是随机初始化的值,则证明训练代码有问题,需要改。

下面介绍如何直接读取网络的weight 和 bias。

(1) 获取参数的变量名。可以使用一下函数获取变量名:

def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]

输入你想要读取的变量的一部分的名称(scope_name_var),然后通过这个函数返回一个List,里面是所有含有这个名称的变量。

(2) 利用session读取变量的值:

def get_weight(self):
 full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv")
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) ##一定要先初始化变量
  print(sess.run(full_connect_variable[0]))

之后如果想要看参数随着训练的变化,你可以将这些参数保存到一个txt文件里面查看。

补充知识:如何在 PyTorch 中设定学习率衰减(learning rate decay)

在tensorflow实现直接读取网络的参数(weight and bias)的值

很多时候我们要对学习率(learning rate)进行衰减,下面的代码示范了如何每30个epoch按10%的速率衰减:

def adjust_learning_rate(optimizer, epoch):
 """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
 lr = args.lr * (0.1 ** (epoch // 30))
 for param_group in optimizer.param_groups:
  param_group['lr'] = lr

什么是param_groups?

optimizer通过param_group来管理参数组.param_group中保存了参数组及其对应的学习率,动量等等.所以我们可以通过更改param_group[‘lr']的值来更改对应参数组的学习率。

# 有两个`param_group`即,len(optim.param_groups)==2
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
   ], lr=1e-2, momentum=0.9)
 
#一个参数组
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)

以上这篇在tensorflow实现直接读取网络的参数(weight and bias)的值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python入门_浅谈数据结构的4种基本类型
May 16 Python
windows下Python实现将pdf文件转化为png格式图片的方法
Jul 21 Python
Window10+Python3.5安装opencv的教程推荐
Apr 02 Python
python 获取指定文件夹下所有文件名称并写入列表的实例
Apr 23 Python
基于python实现学生信息管理系统
Nov 22 Python
python实现用类读取文件数据并计算矩形面积
Jan 18 Python
xadmin使用formfield_for_dbfield函数过滤下拉表单实例
Apr 07 Python
python集合能干吗
Jul 19 Python
基于python实现简单C/S模式代码实例
Sep 14 Python
Django+Django-Celery+Celery的整合实战
Jan 20 Python
只用20行Python代码实现屏幕录制功能
Jun 02 Python
Python可视化学习之seaborn调色盘
Feb 24 Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
pytorch中的weight-initilzation用法
Jun 24 #Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
You might like
php自动跳转中英文页面
2008/07/29 PHP
php 智能404跳转代码,适合换域名没改变目录的网站
2010/06/04 PHP
自己在做项目过程中学到的PHP知识收集
2012/08/20 PHP
真正根据utf8编码的规律来进行截取字符串的函数(utf8版sub_str )
2012/10/24 PHP
php中使用url传递数组的方法
2015/02/11 PHP
PHP封装CURL扩展类实例
2015/07/28 PHP
CI框架支持$_GET的两种实现方法
2016/05/18 PHP
浅谈PHP Cookie处理函数
2016/06/10 PHP
php文件上传原理与实现方法详解
2019/12/20 PHP
一个小型js框架myJSFrame附API使用帮助
2008/06/28 Javascript
IE与FireFox中的childNodes区别
2011/10/20 Javascript
javascript关于运动的各种问题经典总结
2015/04/27 Javascript
js学习阶段总结(必看篇)
2016/06/16 Javascript
javascript日期比较方法实例分析
2016/06/17 Javascript
vue开发心得和技巧分享
2016/10/27 Javascript
angular双向绑定模拟探索
2016/12/26 Javascript
关于axios返回空对象的问题解决
2017/04/04 Javascript
JQuery和html+css实现带小圆点和左右按钮的轮播图实例
2017/07/22 jQuery
SeaJS中use函数用法实例分析
2017/10/10 Javascript
AngularJS与BootStrap模仿百度分页的示例代码
2018/05/23 Javascript
Vue动态组件和异步组件原理详解
2019/05/06 Javascript
vue+django实现一对一聊天功能的实例代码
2019/07/17 Javascript
JS实现商品橱窗特效
2020/01/09 Javascript
Vue 实现监听窗口关闭事件,并在窗口关闭前发送请求
2020/09/01 Javascript
Python标准库之多进程(multiprocessing包)介绍
2014/11/25 Python
Python使用Flask框架获取当前查询参数的方法
2015/03/21 Python
python中根据字符串调用函数的实现方法
2016/06/12 Python
Python的shutil模块中文件的复制操作函数详解
2016/07/05 Python
django框架F&Q 聚合与分组操作示例
2019/12/12 Python
python百行代码自制电脑端网速悬浮窗的实现
2020/05/12 Python
pycharm导入源码的具体步骤
2020/08/04 Python
英国最大的手表网站:The Watch Hut
2017/03/31 全球购物
澳大利亚和新西兰最大的在线旅行社之一:Aunt Betty
2019/08/07 全球购物
奢华的意大利皮革手袋:Bene Handbags
2019/10/29 全球购物
香港士多网上超级市场:Ztore
2021/01/09 全球购物
三年级数学教学反思
2014/01/31 职场文书