pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简介Python设计模式中的代理模式与模板方法模式编程
Feb 02 Python
Python实现读取及写入csv文件的方法示例
Jan 12 Python
python实现求最长回文子串长度
Jan 22 Python
Python装饰器(decorator)定义与用法详解
Feb 09 Python
用Python实现读写锁的示例代码
Nov 05 Python
python批量获取html内body内容的实例
Jan 02 Python
python+openCV调用摄像头拍摄和处理图片的实现
Aug 06 Python
win10子系统python开发环境准备及kenlm和nltk的使用教程
Oct 14 Python
浅析python标准库中的glob
Mar 13 Python
Jupyter notebook 启动闪退问题的解决
Apr 13 Python
使用scrapy ImagesPipeline爬取图片资源的示例代码
Sep 28 Python
如何用python爬取微博热搜数据并保存
Feb 20 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
Zend Framework连接Mysql数据库实例分析
2016/03/19 PHP
php基于mcrypt_encrypt和mcrypt_decrypt实现字符串加密解密的方法
2016/07/12 PHP
PHP简单实现二维数组的矩阵转置操作示例
2017/11/24 PHP
jQuery html()等方法介绍
2009/11/18 Javascript
javascript 进阶篇3 Ajax 、JSON、 Prototype介绍
2012/03/14 Javascript
Jquery实现弹出层分享微博插件具备动画效果
2013/04/03 Javascript
使用js正则控制input标签只允许输入的值
2013/07/29 Javascript
node.js中的fs.chmod方法使用说明
2014/12/18 Javascript
jQuery中removeClass()方法用法实例
2015/01/05 Javascript
跟我学习javascript的函数调用和构造函数调用
2015/11/16 Javascript
jQuery实现点击水纹波动动画
2016/04/10 Javascript
seajs学习之模块的依赖加载及模块API的导出
2016/10/20 Javascript
ThinkPHP+jquery实现“加载更多”功能代码
2017/03/11 Javascript
基于vue的短信验证码倒计时demo
2017/09/13 Javascript
jQuery+ThinkPHP实现图片上传
2020/07/23 jQuery
js实现翻牌小游戏
2020/07/31 Javascript
JavaScript代码模拟鼠标自动点击事件示例
2020/08/07 Javascript
vue项目打包后提交到git上为什么没有dist这个文件的解决方法
2020/09/16 Javascript
JavaScript通如何过RGraph实现动态仪表盘
2020/10/15 Javascript
python计算方程式根的方法
2015/05/07 Python
Go语言基于Socket编写服务器端与客户端通信的实例
2016/02/19 Python
Django objects.all()、objects.get()与objects.filter()之间的区别介绍
2017/06/12 Python
OpenCV-Python实现轮廓检测实例分析
2018/01/05 Python
PyCharm鼠标右键不显示Run unittest的解决方法
2018/11/30 Python
python实现微信定时每天和女友发送消息
2019/04/29 Python
python 日期排序的实例代码
2019/07/11 Python
Pytorch 多块GPU的使用详解
2019/12/31 Python
Python dict的常用方法示例代码
2020/06/23 Python
手把手教你配置JupyterLab 环境的实现
2021/02/02 Python
CSS3点击按钮实现背景渐变动画效果
2016/10/19 HTML / CSS
html5使用canvas实现跟随光标跳动的火焰效果
2014/01/07 HTML / CSS
Linux管理员面试经常问道的相关命令
2014/12/12 面试题
2014年寒假社会实践活动心得体会
2014/04/07 职场文书
公司内部升职自荐信
2015/03/27 职场文书
恰同学少年观后感
2015/06/08 职场文书
新入职员工工作总结
2015/10/15 职场文书