获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中使用百度音乐搜索的api下载指定歌曲的lrc歌词
Jul 18 Python
python中的全局变量用法分析
Jun 09 Python
详解python里使用正则表达式的全匹配功能
Oct 19 Python
Django REST为文件属性输出完整URL的方法
Dec 18 Python
对pandas进行数据预处理的实例讲解
Apr 20 Python
Python基于win32ui模块创建弹出式菜单示例
May 09 Python
Django框架多表查询实例分析
Jul 04 Python
python中字符串内置函数的用法总结
Sep 13 Python
50行Python代码实现视频中物体颜色识别和跟踪(必须以红色为例)
Nov 20 Python
Python3 filecmp模块测试比较文件原理解析
Mar 23 Python
简单了解python调用其他脚本方法实例
Mar 26 Python
Python常遇到的错误和异常
Nov 02 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
Yii框架扩展CGridView增加导出CSV功能的方法
2017/05/24 PHP
JavaScript 脚本将当地时间转换成其它时区
2009/03/19 Javascript
jQuery 性能优化手册 推荐
2010/02/23 Javascript
JavaScript高级程序设计 读书笔记之八 Function类及闭包
2012/02/27 Javascript
ie支持function.bind()方法实现代码
2012/12/27 Javascript
js根据给定的日期计算当月有多少天实现思路及代码
2013/02/25 Javascript
JavaScript函数定义的常见注意事项小结
2014/09/16 Javascript
javascript实时显示北京时间的方法
2015/03/12 Javascript
浅谈jQuery构造函数分析
2015/05/11 Javascript
分享JavaScript与Java中MD5使用两个例子
2015/12/23 Javascript
Vue的百度地图插件尝试使用
2017/09/06 Javascript
通过js控制时间,一秒一秒自己动的实例
2017/10/25 Javascript
js中document.write和document.writeln的区别
2018/03/11 Javascript
vue计算属性+vue中class与style绑定(推荐)
2020/03/30 Javascript
Python正则表达式完全指南
2017/05/25 Python
Django实现的自定义访问日志模块示例
2017/06/23 Python
python try except 捕获所有异常的实例
2018/10/18 Python
python实现AES加密与解密
2019/03/28 Python
python 比较2张图片的相似度的方法示例
2019/12/18 Python
python selenium自动化测试框架搭建的方法步骤
2020/06/14 Python
Python学习之路之pycharm的第一个项目搭建过程
2020/06/18 Python
Python map及filter函数使用方法解析
2020/08/06 Python
Python使用socket模块实现简单tcp通信
2020/08/18 Python
python爬取天气数据的实例详解
2020/11/20 Python
枚举和一组预处理的#define有什么不同
2016/09/21 面试题
中专生自我鉴定
2013/12/17 职场文书
大学系主任推荐信范文
2013/12/24 职场文书
最新创业融资计划书
2014/01/19 职场文书
学历公证书范本
2014/04/09 职场文书
爱心捐款倡议书
2014/04/14 职场文书
小学一年级学生评语
2014/04/22 职场文书
党员组织生活会发言材料
2014/10/17 职场文书
2015年党员承诺书
2015/01/21 职场文书
消防安全月活动总结
2015/05/08 职场文书
大学学习委员竞选稿
2015/11/20 职场文书
什么是Python装饰器?如何定义和使用?
2022/04/11 Python