pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自动安装pip
Apr 24 Python
python字符串连接的N种方式总结
Sep 17 Python
Python手机号码归属地查询代码
May 04 Python
unittest+coverage单元测试代码覆盖操作实例详解
Apr 04 Python
django请求返回不同的类型图片json,xml,html的实例
May 22 Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 Python
python实现Flappy Bird源码
Dec 24 Python
python django model联合主键的例子
Aug 06 Python
Python日志处理模块logging用法解析
May 19 Python
python mock测试的示例
Oct 19 Python
python实现进度条的多种实现
Apr 29 Python
Python  序列化反序列化和异常处理的问题小结
Dec 24 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
如何在PHP中使用Oracle数据库(1)
2006/10/09 PHP
php调用dll的实例操作动画与代码分享
2012/08/14 PHP
php中的filesystem文件系统函数介绍及使用示例
2014/02/13 PHP
php中debug_backtrace、debug_print_backtrace和匿名函数用法实例
2014/12/01 PHP
php递归函数三种实现方法及如何实现数字累加
2015/08/07 PHP
PHP中set_include_path()函数相关用法分析
2016/07/18 PHP
php封装的验证码类分享
2017/02/26 PHP
PHP和MYSQL实现分页导航思路详解
2017/04/11 PHP
PHP图片水印类的封装
2017/07/06 PHP
PHP的cookie与session原理及用法详解
2019/09/27 PHP
一个简单的js鼠标划过切换效果
2010/06/30 Javascript
Ajax 数据请求的简单分析
2011/04/05 Javascript
autoPlay 基于jquery的图片自动播放效果
2011/12/07 Javascript
JQuery中使用Ajax赋值给全局变量异常的解决方法
2014/01/10 Javascript
jQuery的css() 方法使用指南
2015/05/03 Javascript
多种js图片预加载实现方式分享
2016/02/19 Javascript
自动完成的搜索框javascript实现
2016/02/26 Javascript
canvas简单快速的实现知乎登录页背景效果
2017/05/08 Javascript
微信小程序之页面拦截器的示例代码
2017/09/07 Javascript
JS+canvas动态绘制饼图的方法示例
2017/09/12 Javascript
详解Angular5/Angular6项目如何添加热更新(HMR)功能
2018/10/10 Javascript
Taro小程序自定义顶部导航栏功能的实现
2020/12/17 Javascript
python中关于日期时间处理的问答集锦
2013/03/08 Python
python中Pycharm 输出中文或打印中文乱码现象的解决办法
2017/06/16 Python
对pycharm 修改程序运行所需内存详解
2018/12/03 Python
Bodum官网:咖啡和茶壶、玻璃器皿、厨房电器等
2018/08/01 全球购物
荷兰照明、灯具和配件网上商店:dmlights
2019/08/25 全球购物
运动会开幕式邀请函
2014/01/22 职场文书
监察局领导班子四风问题整改措施思想汇报
2014/10/05 职场文书
最美乡村教师观后感
2015/06/11 职场文书
家庭聚会祝酒词
2015/08/11 职场文书
创业计划书之婴幼儿游泳馆
2019/09/11 职场文书
MySQL获取所有分类的前N条记录
2021/05/07 MySQL
python常见的占位符总结及用法
2021/07/02 Python
详解Vue项目的打包方式(生成dist文件)
2022/01/18 Vue.js
vue项目支付功能代码详解
2022/02/18 Vue.js