获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取beautifulphoto随机某图片代码实例
Dec 18 Python
Python去掉字符串中空格的方法
Mar 11 Python
浅谈python numpy中nonzero()的用法
Apr 02 Python
Python获取昨天、今天、明天开始、结束时间戳的方法
Jun 01 Python
pycharm中成功运行图片的配置教程
Oct 28 Python
python opencv minAreaRect 生成最小外接矩形的方法
Jul 01 Python
Python中的self用法详解
Aug 06 Python
Python的bit_length函数来二进制的位数方法
Aug 27 Python
Python自省及反射原理实例详解
Jul 06 Python
python基于tkinter制作下班倒计时工具
Apr 28 Python
解决python3安装pandas出错的问题
May 20 Python
Python实现PIL图像处理库绘制国际象棋棋盘
Jul 16 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
火影忍者:这才是千手柱间和扉间的真正死因,角都就比较搞笑了!
2020/03/10 日漫
解析Linux下Varnish缓存的配置优化
2013/06/20 PHP
php数组编码转换示例详解
2014/03/11 PHP
php使用Imagick生成图片的方法
2015/07/31 PHP
PHP多维数组转一维数组的简单实现方法
2015/12/23 PHP
prototype 1.5相关知识及他人笔记
2006/12/16 Javascript
css值转换成数值请抛弃parseInt
2011/10/24 Javascript
可编辑下拉框的2种实现方式
2014/06/13 Javascript
JS实现关键字搜索时的相关下拉字段效果
2014/08/05 Javascript
原生javaScript实现图片延时加载的方法
2014/12/22 Javascript
jQuery操作JSON的CRUD用法实例
2015/02/25 Javascript
JS实现控制表格只显示行边框或者只显示列边框的方法
2015/03/31 Javascript
javascript限制文本框输入值类型的方法
2015/05/07 Javascript
angularjs中的$eval方法详解
2017/04/24 Javascript
详解关于vue2.0工程发布上线操作步骤
2018/09/27 Javascript
vue之debounce属性被移除及处理详解
2019/11/13 Javascript
基于NodeJS开发钉钉回调接口实现AES-CBC加解密
2020/08/20 NodeJs
详解Python的Django框架中的Cookie相关处理
2015/07/22 Python
Python数据结构与算法之图的最短路径(Dijkstra算法)完整实例
2017/12/12 Python
python 字典中取值的两种方法小结
2018/08/02 Python
pandas dataframe添加表格框线输出的方法
2019/02/08 Python
Django-Model数据库操作(增删改查、连表结构)详解
2019/07/17 Python
Python SSL证书验证问题解决方案
2020/01/13 Python
Pycharm最常用的快捷键及使用技巧
2020/03/05 Python
python实现手势识别的示例(入门)
2020/04/15 Python
Python中使用socks5设置全局代理的方法示例
2020/04/15 Python
西班牙国家航空官方网站:Iberia
2017/11/16 全球购物
巴西美妆购物网站:Kutiz Beauté
2019/03/13 全球购物
澳大利亚有机化妆品网上商店:The Well Store
2020/02/20 全球购物
大专应届生个人的自我评价
2013/11/21 职场文书
个人简历自我评价范文
2014/02/04 职场文书
国培远程培训感言
2014/03/08 职场文书
《泉水》教学反思
2014/04/11 职场文书
大学生评语大全
2014/04/18 职场文书
优秀班主任事迹材料
2014/12/16 职场文书
浅谈 JavaScript 沙箱Sandbox
2021/11/02 Javascript