获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现一个简单的MySQL类
Jan 07 Python
python自然语言编码转换模块codecs介绍
Apr 08 Python
快速排序的算法思想及Python版快速排序的实现示例
Jul 02 Python
django 多数据库配置教程
May 30 Python
在PyCharm的 Terminal(终端)切换Python版本的方法
Aug 02 Python
python 实现Flask中返回图片流给前端展示
Jan 09 Python
python序列类型种类详解
Feb 26 Python
更新升级python和pip版本后不生效的问题解决
Apr 17 Python
python根据完整路径获得盘名/路径名/文件名/文件扩展名的方法
Apr 22 Python
使用keras2.0 将Merge层改为函数式
May 23 Python
浅析Python 条件控制语句
Jul 15 Python
Python实现天气查询软件
Jun 07 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
无线电波是什么?它是怎样传输的?
2021/03/01 无线电
mysql5的sql文件导入到mysql4的方法
2008/10/19 PHP
thinkphp分页实现效果
2016/10/13 PHP
PHP按符号截取字符串的指定部分的实现方法
2018/09/10 PHP
javascript XML数据显示为HTML一例
2008/12/23 Javascript
Javascript Global对象
2009/08/13 Javascript
关于js注册事件的常用方法
2013/04/03 Javascript
使用js实现按钮控制文本框加1减1应用于小时+分钟
2013/12/09 Javascript
jquery淡化版banner异步图片文字效果切换图片特效
2014/04/08 Javascript
JavaScript中实现PHP的打乱数组函数shuffle实例
2014/10/11 Javascript
js中的json对象详细介绍
2014/10/29 Javascript
JavaScript实现向右伸出的多级网页菜单效果
2015/08/25 Javascript
js实现的动画导航菜单效果代码
2015/09/10 Javascript
Bootstrap实现响应式导航栏效果
2015/12/28 Javascript
jQuery添加和删除输入文本框标签代码
2016/05/20 Javascript
Bootstrap实现的经典栅格布局效果实例【附demo源码】
2017/03/30 Javascript
JavaScript实现简单音乐播放器
2020/04/17 Javascript
vue分页插件的使用方法
2019/12/25 Javascript
[49:02]KG vs Infamous 2019国际邀请赛淘汰赛 败者组BO1 8.20.mp4
2020/07/19 DOTA
在Linux下使用Python的matplotlib绘制数据图的教程
2015/06/11 Python
Python使用sftp实现上传和下载功能(实例代码)
2017/03/14 Python
Tornado高并发处理方法实例代码
2018/01/15 Python
python 利用pywifi模块实现连接网络破解wifi密码实时监控网络
2019/09/16 Python
python构建指数平滑预测模型示例
2019/11/21 Python
Python实现CNN的多通道输入实例
2020/01/17 Python
Python Numpy中数据的常用保存与读取方法
2020/04/01 Python
CSS3解析抖音LOGO制作的方法步骤
2019/04/11 HTML / CSS
HTML5 textarea高度自适应的两种方案
2020/04/08 HTML / CSS
巴西购物网站:Submarino
2020/01/19 全球购物
甜品店的创业计划书范文
2014/01/02 职场文书
置业顾问岗位职责
2014/03/02 职场文书
汽车运用工程专业求职信
2014/06/18 职场文书
小学语文教研活动总结
2014/07/01 职场文书
高中校园广播稿
2014/10/21 职场文书
2015年秋季灭鼠工作总结
2015/07/27 职场文书
vue+elementUI实现表格列的显示与隐藏
2022/04/13 Vue.js