pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python文档生成工具pydoc使用介绍
Jun 02 Python
python pandas实现excel转为html格式的方法
Oct 23 Python
python学习--使用QQ邮箱发送邮件代码实例
Apr 16 Python
pyqt5 tablewidget 利用线程动态刷新数据的方法
Jun 17 Python
python实现比较类的两个instance(对象)是否相等的方法分析
Jun 26 Python
详解将Pandas中的DataFrame类型转换成Numpy中array类型的三种方法
Jul 06 Python
python执行scp命令拷贝文件及文件夹到远程主机的目录方法
Jul 08 Python
pyqt5、qtdesigner安装和环境设置教程
Sep 25 Python
Python3实现二叉树的最大深度
Sep 30 Python
calendar在python3时间中常用函数举例详解
Nov 18 Python
使用numpy实现矩阵的翻转(flip)与旋转
Jun 03 Python
使用scrapy实现增量式爬取方式
Jun 21 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
《忧国的莫里亚蒂》先导宣传图与STAFF公开
2020/03/04 日漫
学习使用PHP数组
2006/10/09 PHP
浅谈php错误提示及查错方法
2015/07/14 PHP
PHP+Ajax无刷新带进度条图片上传示例
2017/02/08 PHP
浅析PHP数据导出知识点
2018/02/17 PHP
统一接口:为FireFox添加IE的方法和属性的js代码
2007/03/25 Javascript
div拖拽插件——JQ.MoveBox.js(自制JQ插件)
2013/05/17 Javascript
JS画5角星方法介绍
2013/09/17 Javascript
JS 操作Array数组的方法及属性实例解析
2014/01/08 Javascript
无刷新预览所选择的图片示例代码
2014/04/02 Javascript
jQuery ajax调用WCF服务实例
2014/07/16 Javascript
轻量级网页遮罩层jQuery插件用法实例
2015/07/31 Javascript
jQuery实现点击按钮弹出可关闭层的浮动层插件
2015/09/19 Javascript
jQuery定义插件的方法
2015/12/18 Javascript
基于CSS3和jQuery实现跟随鼠标方位的Hover特效
2016/07/25 Javascript
javascript数组常用方法汇总
2016/09/10 Javascript
Bootstrap fileinput组件封装及使用详解
2017/03/10 Javascript
B/S(Web)实时通讯解决方案分享
2017/04/06 Javascript
Angular简单验证功能示例
2017/12/22 Javascript
JS中offset和匀速动画详解
2018/02/06 Javascript
原生js实现拖拽功能基本思路详解
2018/04/18 Javascript
[07:20]2014DOTA2西雅图国际邀请赛 选手讲解积分赛第二天
2014/07/11 DOTA
[53:15]Newbee vs Pain 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
python基于windows平台锁定键盘输入的方法
2015/03/05 Python
pycharm下打开、执行并调试scrapy爬虫程序的方法
2017/11/29 Python
Python查找两个有序列表中位数的方法【基于归并算法】
2018/04/20 Python
Python if语句知识点用法总结
2018/06/10 Python
彻彻底底地理解Python中的编码问题
2018/10/15 Python
解决pycharm 误删掉项目文件的处理方法
2018/10/22 Python
python整合ffmpeg实现视频文件的批量转换
2019/05/31 Python
详解程序意外中断自动重启shell脚本(以Python为例)
2019/07/26 Python
简单易懂Pytorch实战实例VGG深度网络
2019/08/27 Python
房地产出纳岗位职责
2013/12/01 职场文书
战友聚会邀请函
2014/01/18 职场文书
2014年安全生产目标责任书
2014/07/23 职场文书
2016党员党课心得体会
2016/01/07 职场文书