pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现去除下载电影和电视剧文件名中的多余字符的方法
Sep 23 Python
状态机的概念和在Python下使用状态机的教程
Apr 11 Python
Python中利用原始套接字进行网络编程的示例
May 04 Python
在Python 3中实现类型检查器的简单方法
Jul 03 Python
利用python爬取散文网的文章实例教程
Jun 18 Python
使用python为mysql实现restful接口
Jan 05 Python
Python 查找字符在字符串中的位置实例
May 02 Python
Python3内置模块之base64编解码方法详解
Jul 13 Python
python实现多线程端口扫描
Aug 31 Python
Python代码生成视频的缩略图的实例讲解
Dec 22 Python
用Python自动清理电脑内重复文件,只要10行代码(自动脚本)
Jan 09 Python
Pillow图像处理库安装及使用
Apr 12 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
PHP 文章中的远程图片采集到本地的代码
2009/07/30 PHP
PHP使用Redis长连接的方法详解
2018/02/12 PHP
javascript iframe编程相关代码
2009/12/28 Javascript
javascript与CSS复习(三)
2010/06/29 Javascript
JavaScript中json对象和string对象之间相互转化
2012/12/26 Javascript
javascript中AJAX用法实例分析
2015/01/30 Javascript
JavaScript检查弹出窗口是否被阻拦的方法技巧
2015/03/13 Javascript
javascript正则表达式基础知识入门
2015/04/20 Javascript
javascript框架设计之浏览器的嗅探和特征侦测
2015/06/23 Javascript
jQuery插件HighCharts实现的2D条状图效果示例【附demo源码下载】
2017/03/15 Javascript
基于javascript中的typeof和类型判断(详解)
2017/10/27 Javascript
快速了解vue-cli 3.0 新特性
2018/02/28 Javascript
在vue中实现点击选择框阻止弹出层消失的方法
2018/09/15 Javascript
详解es6新增数组方法简便了哪些操作
2019/05/09 Javascript
Vue开发中遇到的跨域问题及解决方法
2020/02/11 Javascript
javascript-hashchange事件和历史状态管理实例分析
2020/04/18 Javascript
js实现上传按钮并显示缩略图小轮子
2020/05/04 Javascript
js实现验证码干扰(动态)
2021/02/23 Javascript
[01:00:25]NB vs Secret 2018国际邀请赛小组赛BO1 B组加赛 8.19
2018/08/21 DOTA
[01:10:58]KG vs TNC 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/16 DOTA
仅利用30行Python代码来展示X算法
2015/04/01 Python
Python获取当前页面内所有链接的四种方法对比分析
2017/08/19 Python
Python最火、R极具潜力 2017机器学习调查报告
2017/12/11 Python
Python Dataframe 指定多列去重、求差集的方法
2018/07/10 Python
Python实现按逗号分隔列表的方法
2018/10/23 Python
pyside+pyqt实现鼠标右键菜单功能
2020/12/08 Python
Python使用Tkinter实现滚动抽奖器效果
2020/01/06 Python
python交互模式基础知识点学习
2020/06/18 Python
python 实现关联规则算法Apriori的示例
2020/09/30 Python
HTML5之SVG 2D入门7—SVG元素的重用与引用
2013/01/30 HTML / CSS
h5使用canvas画布实现手势解锁
2019/01/04 HTML / CSS
英国领先的男士美容护发用品公司:Mankind
2016/08/31 全球购物
New delete 与malloc free 的联系与区别
2013/02/04 面试题
2014年生活老师工作总结
2014/12/23 职场文书
以权谋私检举信范文
2015/03/02 职场文书
MySql子查询IN的执行和优化的实现
2021/08/02 MySQL