获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程之TCP通信实例和socketserver框架使用例子
Apr 25 Python
解决python大批量读写.doc文件的问题
May 08 Python
python爬虫之线程池和进程池功能与用法详解
Aug 02 Python
Python Excel处理库openpyxl使用详解
May 09 Python
Python GUI编程 文本弹窗的实例
Jun 11 Python
pycharm运行scrapy过程图解
Nov 22 Python
Python基于pyecharts实现关联图绘制
Mar 27 Python
tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this T
Jun 22 Python
python 爬虫之selenium可视化爬虫的实现
Dec 04 Python
Python调用高德API实现批量地址转经纬度并写入表格的功能
Jan 12 Python
Python使用protobuf序列化和反序列化的实现
May 19 Python
Python实现简单的猜单词
Jun 15 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
php桌面中心(二) 数据库写入
2007/03/11 PHP
PDO版本问题 Invalid parameter number: no parameters were bound
2013/01/06 PHP
php生成html文件方法总结
2014/12/01 PHP
HR vs CL BO3 第二场 2.13
2021/03/10 DOTA
window.location.hash 使用说明
2010/11/08 Javascript
jQuery+CSS 半开折叠效果原理及代码(自写)
2013/03/04 Javascript
jquery 清空file域示例(兼容个浏览器)
2013/10/11 Javascript
jQuery实现的图片分组切换焦点图插件
2015/01/06 Javascript
javascript显式类型转换实例分析
2015/04/25 Javascript
JavaScript实现select添加option
2015/07/03 Javascript
基于javascript的Form表单验证
2016/12/29 Javascript
Angular在一个页面中使用两个ng-app的方法
2017/02/20 Javascript
Google 爬虫如何抓取 JavaScript 的内容
2017/04/07 Javascript
微信小程序图片宽100%显示并且不变形
2017/06/21 Javascript
AngularJS实现注册表单验证功能
2017/10/16 Javascript
几个你不知道的技巧助你写出更优雅的vue.js代码
2018/06/11 Javascript
详解一个基于套接字实现长连接的express
2019/03/28 Javascript
vuex(vue状态管理)的特殊应用案例分享
2020/03/03 Javascript
Paypal支付不完全指北
2020/06/04 Javascript
javascript自定义加载loading效果
2020/09/15 Javascript
[06:37]2014DOTA2国际邀请赛 昔日王者渴望重回巅峰
2014/07/12 DOTA
Python使用PIL库实现验证码图片的方法
2016/03/11 Python
python和flask中返回JSON数据的方法
2018/03/26 Python
matplotlib.pyplot画图 图片的二进制流的获取方法
2018/05/24 Python
Python3.5多进程原理与用法实例分析
2019/04/05 Python
python实现拉普拉斯特征图降维示例
2019/11/25 Python
Pytorch实现各种2d卷积示例
2019/12/30 Python
Python使用Pandas库常见操作详解
2020/01/16 Python
Python如何进行时间处理
2020/08/06 Python
莫斯科珠宝厂官方网站:Miuz
2020/09/19 全球购物
构造器Constructor是否可被override?
2013/08/06 面试题
Android面试题附答案
2014/12/08 面试题
民族团结先进个人材料
2014/02/05 职场文书
家长会演讲稿
2014/04/26 职场文书
2014年作风建设心得体会
2014/10/22 职场文书
2016应届毕业生就业指导课心得体会
2016/01/15 职场文书