获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中函数的参数传递与可变长参数介绍
Jun 30 Python
在arcgis使用python脚本进行字段计算时是如何解决中文问题的
Oct 18 Python
十条建议帮你提高Python编程效率
Feb 16 Python
Python的Django中将文件上传至七牛云存储的代码分享
Jun 03 Python
Python简单获取二维数组行列数的方法示例
Dec 21 Python
在PyCharm下使用 ipython 交互式编程的方法
Jan 17 Python
python与字符编码问题
May 24 Python
使用Python实现毫秒级抢单功能
Jun 06 Python
python 执行终端/控制台命令的例子
Jul 12 Python
使用Python防止SQL注入攻击的实现示例
May 21 Python
Python错误的处理方法
Jun 23 Python
python和node.js生成当前时间戳的示例
Sep 29 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
PHP下操作Linux消息队列完成进程间通信的方法
2010/07/24 PHP
php保存信息到当前Session的方法
2015/03/16 PHP
PHP 中 Orientation 属性判断上传图片是否需要旋转
2015/10/16 PHP
Javascript将string类型转换int类型
2010/12/09 Javascript
JavaScript初学者应注意的七个细节小结
2012/01/30 Javascript
javascript 继承学习心得总结
2016/03/17 Javascript
Vue.js每天必学之构造器与生命周期
2016/09/05 Javascript
AngularJS中isolate scope的用法分析
2016/11/22 Javascript
js代码延迟一定时间后执行一个函数的实例
2017/02/15 Javascript
Vue项目中引入外部文件的方法(css、js、less)
2017/07/24 Javascript
vue element table 表格请求后台排序的方法
2018/09/28 Javascript
Angular4.x Event (DOM事件和自定义事件详解)
2018/10/09 Javascript
Vue插值、表达式、分隔符、指令知识小结
2018/10/12 Javascript
vue中 数字相加为字串转化为数值的例子
2019/11/07 Javascript
按日期打印Python的Tornado框架中的日志的方法
2015/05/02 Python
分析并输出Python代码依赖的库的实现代码
2015/08/09 Python
python中redis的安装和使用
2016/12/04 Python
几种实用的pythonic语法实例代码
2018/02/24 Python
python针对excel的操作技巧
2018/03/13 Python
python使用adbapi实现MySQL数据库的异步存储
2019/03/19 Python
Python对象转换为json的方法步骤
2019/04/25 Python
ASP.NET Core中的配置详解
2021/02/05 Python
html5使用Drag事件编辑器拖拽上传图片的示例代码
2017/08/22 HTML / CSS
迪奥官网:Dior.com
2018/12/04 全球购物
泰国最新活动和优惠:Megatix
2020/05/07 全球购物
决心书范文
2014/03/11 职场文书
2014年乡镇民政工作总结
2014/12/02 职场文书
限期整改通知书
2015/04/22 职场文书
2016年班主任培训心得体会
2016/01/07 职场文书
三年级作文之小小梦想
2019/12/06 职场文书
扩展多台相同的Web服务器
2021/04/01 Servers
python 提取html文本的方法
2021/05/20 Python
详解overflow:hidden的作用(溢出隐藏、清除浮动、解决外边距塌陷)
2021/07/01 HTML / CSS
Python+Selenium实现读取网易邮箱验证码
2022/03/13 Python
SQL Server Agent 服务无法启动
2022/04/20 SQL Server
Django中celery的使用项目实例
2022/07/07 Python