pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中操作列表之List.pop()方法的使用
May 21 Python
实例探究Python以并发方式编写高性能端口扫描器的方法
Jun 14 Python
Python中关键字nonlocal和global的声明与解析
Mar 12 Python
python、java等哪一门编程语言适合人工智能?
Nov 13 Python
几种实用的pythonic语法实例代码
Feb 24 Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 Python
Python3.5内置模块之shelve模块、xml模块、configparser模块、hashlib、hmac模块用法分析
Apr 27 Python
Python实现字典按key或者value进行排序操作示例【sorted】
May 03 Python
Python+Redis实现布隆过滤器
Dec 08 Python
opencv之为图像添加边界的方法示例
Dec 26 Python
Python使用Turtle模块绘制国旗的方法示例
Feb 28 Python
python APScheduler执行定时任务介绍
Apr 19 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
盘点被央视点名过的日本动画电影 一部比一部强
2020/03/08 日漫
php中对xml读取的相关函数的介绍一
2008/06/05 PHP
黑夜路人出的几道php笔试题
2009/08/04 PHP
php学习笔记之基础知识
2014/11/08 PHP
CodeIgniter配置之config.php用法实例分析
2016/01/19 PHP
Yii2实现UploadedFile上传文件示例
2017/02/15 PHP
详解Yii2.0使用AR联表查询实例
2017/06/16 PHP
JavaScript 一行代码,轻松搞定浮动快捷留言-V2升级版
2010/04/02 Javascript
JavaScript 垃圾回收机制分析
2013/10/10 Javascript
JS中实现简单Formatter函数示例代码
2014/08/19 Javascript
javascript实现连续赋值
2015/08/10 Javascript
纯javascript模仿微信打飞机小游戏
2015/08/20 Javascript
jQuery实现背景弹性滚动的导航效果
2016/06/01 Javascript
AngularJS 与Bootstrap实现表格分页实例代码
2016/10/14 Javascript
详解JavaScript中return的用法
2017/05/08 Javascript
nodejs 简单实现动态html的方法
2018/05/12 NodeJs
vue中使用better-scroll实现滑动效果及注意事项
2018/11/15 Javascript
[01:14]2014DOTA2展望TI 剑指西雅图newbee战队专访
2014/06/30 DOTA
Python实现删除Android工程中的冗余字符串
2015/01/19 Python
Python字符串匹配算法KMP实例
2015/07/18 Python
Python学习思维导图(必看篇)
2017/06/26 Python
获取python的list中含有重复值的index方法
2018/06/27 Python
Python安装tar.gz格式文件方法详解
2020/01/19 Python
解决Keras中循环使用K.ctc_decode内存不释放的问题
2020/06/29 Python
HTML5 Web Database 数据库的SQL语句的使用方法
2012/12/09 HTML / CSS
main 主函数执行完毕后,是否可能会再执行一段代码,给出说明
2012/12/05 面试题
《画家乡》教学反思
2014/04/22 职场文书
产品开发计划书
2014/04/27 职场文书
《陈毅探母》教学反思
2014/05/01 职场文书
精神文明建设先进工作者事迹材料
2014/05/02 职场文书
2014最新党员批评与自我批评材料
2014/09/24 职场文书
行政执法作风整顿剖析材料
2014/10/11 职场文书
2015年乡镇卫生院妇幼保健工作总结
2015/05/19 职场文书
2016年村干部公开承诺书(公开承诺事项)
2016/03/25 职场文书
Python用tkinter实现自定义记事本的方法详解
2022/03/31 Python
Python万能模板案例之matplotlib绘制甘特图
2022/04/13 Python