pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python装饰器验证配置文件示例
Feb 24 Python
Python collections模块实例讲解
Apr 07 Python
Python中的两个内置模块介绍
Apr 05 Python
python中私有函数调用方法解密
Apr 29 Python
Python实现动态图解析、合成与倒放
Jan 18 Python
Python中单例模式总结
Feb 20 Python
Python内置函数reversed()用法分析
Mar 20 Python
Python中数组,列表:冒号的灵活用法介绍(np数组,列表倒序)
Apr 18 Python
Python3实现的简单验证码识别功能示例
May 02 Python
python爬虫获取新浪新闻教学
Dec 23 Python
详解python中TCP协议中的粘包问题
Mar 22 Python
自动在Windows中运行Python脚本并定时触发功能实现
Sep 04 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
编译问题
2006/10/09 PHP
PHP字符串函数系列之nl2br(),在字符串中的每个新行 (\n) 之前插入 HTML 换行符br
2011/11/10 PHP
php获取数组长度的方法(有实例)
2013/10/27 PHP
举例讲解PHP面对对象编程的多态
2015/08/12 PHP
smarty简单应用实例
2015/11/03 PHP
PHP与服务器文件系统的简单交互
2016/10/21 PHP
PHP实现基于PDO扩展连接PostgreSQL对象关系数据库示例
2018/03/31 PHP
PHP vsprintf()函数格式化字符串操作原理解析
2020/07/14 PHP
用roll.js实现的图片自动滚动+鼠标触动的特效
2007/03/18 Javascript
jQuery 扩展对input的一些操作方法
2009/10/30 Javascript
javascript或asp实现的判断身份证号码是否正确两种验证方法
2009/11/26 Javascript
jquery动态加载图片数据练习代码
2011/08/04 Javascript
jQuery 获取URL的GET参数值的小例子
2013/04/18 Javascript
使用javascript实现判断当前浏览器
2015/04/14 Javascript
Javascript中For In语句用法实例
2015/05/14 Javascript
jqGrid中文文档之选项设置
2015/12/02 Javascript
JS判断当前页面是否在微信浏览器打开的方法
2015/12/08 Javascript
详解nodejs与javascript中的aes加密
2016/05/22 NodeJs
js动态添加的DIV中的onclick事件简单实例
2016/07/25 Javascript
JS开发中百度地图+城市联动实现实时触发查询地址功能
2017/04/13 Javascript
jquery tmpl模板(实例讲解)
2017/09/02 jQuery
自定义vue组件发布到npm的方法
2018/05/09 Javascript
react-native动态切换tab组件的方法
2018/07/07 Javascript
Vue头像处理方案小结
2018/07/26 Javascript
JS中判断字符串存在和非空的方法
2018/09/12 Javascript
手把手带你封装一个vue component第三方库
2019/02/14 Javascript
vue 返回上一页,页面样式错乱的解决
2019/11/14 Javascript
一篇超完整的Vue新手入门指导教程
2020/11/18 Vue.js
python base64 decode incorrect padding错误解决方法
2015/01/08 Python
python使用建议技巧分享(三)
2020/08/18 Python
使用python把xmind转换成excel测试用例的实现代码
2020/10/12 Python
python 基于PYMYSQL使用MYSQL数据库
2020/12/24 Python
python中K-means算法基础知识点
2021/01/25 Python
国际礼品店:GiftsnIdeas
2018/05/03 全球购物
汇智创新科技发展有限公司
2015/12/06 面试题
激励员工的口号
2014/06/16 职场文书