在tensorflow实现直接读取网络的参数(weight and bias)的值


Posted in Python onJune 24, 2020

训练好了一个网络,想要查看网络里面参数是否经过BP算法优化过,可以直接读取网络里面的参数,如果一直是随机初始化的值,则证明训练代码有问题,需要改。

下面介绍如何直接读取网络的weight 和 bias。

(1) 获取参数的变量名。可以使用一下函数获取变量名:

def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]

输入你想要读取的变量的一部分的名称(scope_name_var),然后通过这个函数返回一个List,里面是所有含有这个名称的变量。

(2) 利用session读取变量的值:

def get_weight(self):
 full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv")
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) ##一定要先初始化变量
  print(sess.run(full_connect_variable[0]))

之后如果想要看参数随着训练的变化,你可以将这些参数保存到一个txt文件里面查看。

补充知识:如何在 PyTorch 中设定学习率衰减(learning rate decay)

在tensorflow实现直接读取网络的参数(weight and bias)的值

很多时候我们要对学习率(learning rate)进行衰减,下面的代码示范了如何每30个epoch按10%的速率衰减:

def adjust_learning_rate(optimizer, epoch):
 """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
 lr = args.lr * (0.1 ** (epoch // 30))
 for param_group in optimizer.param_groups:
  param_group['lr'] = lr

什么是param_groups?

optimizer通过param_group来管理参数组.param_group中保存了参数组及其对应的学习率,动量等等.所以我们可以通过更改param_group[‘lr']的值来更改对应参数组的学习率。

# 有两个`param_group`即,len(optim.param_groups)==2
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
   ], lr=1e-2, momentum=0.9)
 
#一个参数组
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)

以上这篇在tensorflow实现直接读取网络的参数(weight and bias)的值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python里对list中的整数求平均并排序
Sep 12 Python
python中MySQLdb模块用法实例
Nov 10 Python
利用Python2下载单张图片与爬取网页图片实例代码
Dec 25 Python
Python基于递归算法实现的汉诺塔与Fibonacci数列示例
Apr 18 Python
使用python存储网页上的图片实例
May 22 Python
查看python下OpenCV版本的方法
Aug 03 Python
sklearn-SVC实现与类参数详解
Dec 10 Python
python实现在内存中读写str和二进制数据代码
Apr 24 Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 Python
解决pip安装tensorflow中出现的no module named tensorflow.python 问题方法
Feb 20 Python
详解python第三方库的安装、PyInstaller库、random库
Mar 03 Python
整理Python中常用的conda命令操作
Jun 15 Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
pytorch中的weight-initilzation用法
Jun 24 #Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
You might like
防止本地用户用fsockopen DDOS攻击对策
2011/11/02 PHP
thinkphp关于简单的权限判定方法
2017/04/03 PHP
PHP基于DOMDocument解析和生成xml的方法分析
2017/07/17 PHP
PHP常用header头定义代码示例汇总
2020/08/29 PHP
Jquery 弹出层插件实现代码
2009/10/24 Javascript
node.js入门教程迷你书、node.js入门web应用开发完全示例
2014/04/06 Javascript
js检测输入内容全为空格的方法
2014/05/03 Javascript
Nodejs实现的一个简单udp广播服务器、客户端
2014/09/25 NodeJs
javascript中html字符串转化为jquery dom对象的方法
2015/08/27 Javascript
第三章之Bootstrap 表格与按钮功能
2016/04/25 Javascript
VUEJS实战之构建基础并渲染出列表(1)
2016/06/13 Javascript
javascript 动态样式添加的简单实现
2016/10/11 Javascript
JS中Select下拉列表类(支持输入模糊查询)功能
2017/01/17 Javascript
jQuery ajax请求struts action实现异步刷新
2017/04/19 jQuery
详解webpack解惑:require的五种用法
2017/06/09 Javascript
详解Node使用Puppeteer完成一次复杂的爬虫
2018/04/18 Javascript
微信小程序登录换取token的教程
2018/05/31 Javascript
在vue中根据光标的显示与消失实现下拉列表
2019/09/29 Javascript
Vue强制组件重新渲染的方法讨论
2020/02/03 Javascript
Python3.4实现从HTTP代理网站批量获取代理并筛选的方法示例
2017/09/26 Python
python bmp转换为jpg 并删除原图的方法
2018/10/25 Python
python 缺失值处理的方法(Imputation)
2019/07/02 Python
python中hasattr()、getattr()、setattr()函数的使用
2019/08/16 Python
Pyspark获取并处理RDD数据代码实例
2020/03/27 Python
python两种获取剪贴板内容的方法
2020/11/06 Python
CSS3实现的闪烁跳跃进度条示例(附源码)
2013/08/19 HTML / CSS
乐高积木玩具美国官网:LEGO Shop US
2016/09/16 全球购物
毕业求职自荐信格式是什么
2013/11/19 职场文书
租车协议书范本
2014/04/22 职场文书
高三霸气励志标语
2014/06/24 职场文书
永远跟党走演讲稿
2014/09/12 职场文书
个人融资协议书范本两则
2014/10/15 职场文书
2014年安全员工作总结
2014/11/13 职场文书
2014年路政工作总结
2014/12/10 职场文书
2015年保险公司工作总结
2015/04/24 职场文书
深入理解python协程
2021/06/15 Python