pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Tkinter GUI编程入门介绍
Mar 10 Python
python实现分页效果
Oct 25 Python
python实现可视化动态CPU性能监控
Jun 21 Python
wtfPython—Python中一组有趣微妙的代码【收藏】
Aug 31 Python
python地震数据可视化详解
Jun 18 Python
利用Python模拟登录pastebin.com的实现方法
Jul 12 Python
Python 操作 ElasticSearch的完整代码
Aug 04 Python
Python字典中的值为列表或字典的构造实例
Dec 16 Python
Pandas时间序列:时期(period)及其算术运算详解
Feb 25 Python
python实现扫雷游戏
Mar 03 Python
python与idea的集成的实现
Nov 20 Python
Selenium Webdriver元素定位的八种常用方式(小结)
Jan 13 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
PHP 中英文混合排版中处理字符串常用的函数
2007/04/12 PHP
php设计模式 Facade(外观模式)
2011/06/26 PHP
解析PHP处理换行符的问题 \r\n
2013/06/13 PHP
解析posix与perl标准的正则表达式区别
2013/06/17 PHP
PHP输出两个数字中间有多少个回文数的方法
2015/03/23 PHP
php生成图片缩略图的方法
2015/04/07 PHP
PHP汉字转换拼音的函数代码
2015/12/30 PHP
TNC vs BOOM BO3 第三场2.13
2021/03/10 DOTA
jqPlot 基于jquery的画图插件
2011/04/26 Javascript
jquery 图片上传按比例预览插件集合
2011/05/28 Javascript
window.open以post方式将内容提交到新窗口
2012/12/26 Javascript
js比较日期大小的方法
2015/05/12 Javascript
纯javascript制作日历控件
2015/07/17 Javascript
jquery移动端TAB触屏切换实现效果
2020/12/22 Javascript
jQuery之简单的表单验证实例
2016/07/07 Javascript
浅析JavaScript中break、continue和return的区别
2016/11/30 Javascript
javascript数据结构中栈的应用之符号平衡问题
2017/04/11 Javascript
微信小程序-getUserInfo回调的实例详解
2017/10/27 Javascript
分享vue里swiper的一些坑
2018/08/30 Javascript
Vue项目报错:Uncaught SyntaxError: Unexpected token
2018/11/10 Javascript
vue指令v-html使用过滤器filters功能实例
2019/10/25 Javascript
jQuery实现动态加载瀑布流
2020/09/01 jQuery
Vue + ts实现轮播插件的示例
2020/11/10 Javascript
[51:00]Secret vs VGJ.S 2018国际邀请赛淘汰赛BO3 第一场 8.24
2018/08/25 DOTA
python字符串str和字节数组相互转化方法
2017/03/18 Python
Python实现爬虫从网络上下载文档的实例代码
2018/06/13 Python
python提取包含关键字的整行数据方法
2018/12/11 Python
解决django服务器重启端口被占用的问题
2019/07/26 Python
Pytorch加载部分预训练模型的参数实例
2019/08/18 Python
python opencv把一张图片嵌入(叠加)到另一张图片上的实现代码
2020/06/11 Python
Python pip 常用命令汇总
2020/10/19 Python
给水排水工程专业毕业生推荐信
2013/10/28 职场文书
工程质量承诺书范文
2014/03/27 职场文书
销售合作意向书范本
2015/05/08 职场文书
纪检干部学习心得体会
2016/01/23 职场文书
Java Optional<Foo>转换成List<Bar>的实例方法
2021/06/20 Java/Android