pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现数通设备tftp备份配置文件示例
Apr 02 Python
利用Django框架中select_related和prefetch_related函数对数据库查询优化
Apr 01 Python
python操作mysql数据库
Mar 05 Python
python中Pycharm 输出中文或打印中文乱码现象的解决办法
Jun 16 Python
详解 Python 读写XML文件的实例
Aug 02 Python
利用python为运维人员写一个监控脚本
Mar 25 Python
python3+opencv3识别图片中的物体并截取的方法
Dec 05 Python
python try 异常处理(史上最全)
Mar 07 Python
Python IDE Pycharm中的快捷键列表用法
Aug 08 Python
python scrapy重复执行实现代码详解
Dec 28 Python
ubuntu 安装pyqt5和卸载pyQt5的方法
Mar 24 Python
新手学习Python2和Python3中print不同的用法
Jun 09 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
在PHP里得到前天和昨天的日期的代码
2007/08/16 PHP
DISCUZ 论坛管理员密码忘记的解决方法
2009/05/14 PHP
Codeigniter操作数据库表的优化写法总结
2014/06/12 PHP
使用PHP生成PDF方法详解
2015/01/23 PHP
PHP会员找回密码功能的简单实现
2016/09/05 PHP
php的socket编程详解
2016/11/20 PHP
利用phpexcel对数据库数据的导入excel(excel筛选)、导出excel
2017/04/27 PHP
Laravel框架模板继承操作示例
2018/06/11 PHP
JQuery 选择器 xpath 语法应用
2010/05/13 Javascript
JavaScript仿聊天室聊天记录
2016/12/27 Javascript
Angular组件化管理实现方法分析
2017/03/17 Javascript
angular2路由切换改变页面title的示例代码
2017/08/23 Javascript
初识 Vue.js 中的 *.Vue文件
2017/11/22 Javascript
jQuery EasyUI 折叠面板accordion的使用实例(分享)
2017/12/25 jQuery
Vue中v-for的数据分组实例
2018/03/07 Javascript
vue的mixins属性详解
2018/03/14 Javascript
基于IView中on-change属性的使用详解
2018/03/15 Javascript
js拖动滑块和点击水波纹效果实例代码
2018/10/16 Javascript
关于layui表单中按钮自动提交的解决方法
2019/09/09 Javascript
8个非常实用的Vue自定义指令
2020/12/15 Vue.js
[02:48]DOTA2英雄基础教程 拉席克
2013/12/12 DOTA
Python中的条件判断语句基础学习教程
2016/02/07 Python
python添加模块搜索路径方法
2017/09/11 Python
Python基于高斯消元法计算线性方程组示例
2018/01/17 Python
基于python3 OpenCV3实现静态图片人脸识别
2018/05/25 Python
opencv python 傅里叶变换的使用
2018/07/21 Python
python提取具有某种特定字符串的行数据方法
2018/12/11 Python
Python中asyncio模块的深入讲解
2019/06/10 Python
使用python实现哈希表、字典、集合操作
2019/12/22 Python
Dune London官网:英国著名奢华鞋履品牌
2017/11/30 全球购物
团队精神的演讲稿
2014/05/14 职场文书
节约用电标语
2014/06/17 职场文书
暑期社会实践心得体会
2014/09/02 职场文书
2015年感恩父亲节活动策划方案
2015/05/05 职场文书
python引入其他文件夹下的py文件具体方法
2021/05/23 Python
mongodb清除连接和日志的正确方法分享
2021/09/15 MongoDB