pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 实现「食行生鲜」签到领积分功能
Sep 26 Python
Python补齐字符串长度的实例
Nov 15 Python
利用Pycharm断点调试Python程序的方法
Nov 29 Python
对python同一个文件夹里面不同.py文件的交叉引用方法详解
Dec 15 Python
Python生成器的使用方法和示例代码
Mar 04 Python
Python使用sklearn库实现的各种分类算法简单应用小结
Jul 04 Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 Python
利用Python脚本实现自动刷网课
Feb 03 Python
Python文本文件的合并操作方法代码实例
Mar 31 Python
python对指定字符串逆序的6种方法(小结)
Apr 02 Python
python打开音乐文件的实例方法
Jul 21 Python
Python爬虫Scrapy框架CrawlSpider原理及使用案例
Nov 20 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
一个php作的文本留言本的例子(三)
2006/10/09 PHP
用Php实现链结人气统计
2006/10/09 PHP
php入门学习知识点四 PHP正则表达式基本应用
2011/07/14 PHP
PHP5中GD库生成图形验证码(有汉字)
2013/07/28 PHP
php fread读取文件注意事项
2016/09/24 PHP
cakephp2.X多表联合查询join及使用分页查询的方法
2017/02/23 PHP
使用jQuery.wechat构建微信WEB应用
2014/10/09 Javascript
JavaScript中提前声明变量或函数例子
2014/11/12 Javascript
Ajax中解析Json的两种方法对比分析
2015/06/25 Javascript
jqGrid中文文档之选项设置
2015/12/02 Javascript
jQuery选择器用法实例详解
2015/12/17 Javascript
jQuery Mobile框架中的表单组件基础使用教程
2016/05/17 Javascript
为jQuery-easyui的tab组件添加右键菜单功能的简单实例
2016/10/10 Javascript
关于Function中的bind()示例详解
2016/12/02 Javascript
jQuery实现别踩白块儿网页版小游戏
2017/01/18 Javascript
Vue自定义指令详解
2017/07/28 Javascript
angular.js4使用 RxJS 处理多个 Http 请求
2017/09/23 Javascript
three.js中文文档学习之通过模块导入
2017/11/20 Javascript
微信小程序实现的一键连接wifi功能示例
2019/04/24 Javascript
Python实现类似jQuery使用中的链式调用的示例
2016/06/16 Python
python中lambda()的用法
2017/11/16 Python
Python机器学习logistic回归代码解析
2018/01/17 Python
Python实现高斯函数的三维显示方法
2018/12/29 Python
对python中基于tcp协议的通信(数据传输)实例讲解
2019/07/22 Python
Python中PyQt5/PySide2的按钮控件使用实例
2019/08/17 Python
python用pip install时安装失败的一系列问题及解决方法
2020/02/24 Python
Python requests模块安装及使用教程图解
2020/06/30 Python
使用css3和jquery实现可伸缩搜索框
2014/02/12 HTML / CSS
20佳惊艳的HTML5应用程序示例分享
2011/05/03 HTML / CSS
澳大利亚游乐场设备品牌:Lifespan Kids
2019/05/24 全球购物
年度考核自我评价
2014/01/25 职场文书
债务纠纷委托书
2014/08/30 职场文书
如何把新闻人物写得立体、鲜活?
2019/08/14 职场文书
配置nginx 重定向到系统维护页面
2021/06/08 Servers
Redis缓存-序列化对象存储乱码问题的解决
2021/06/21 Redis
Redis分布式锁的7种实现
2022/04/01 Redis