获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量生成本地ip地址的方法
Mar 23 Python
Python的Flask框架中web表单的教程
Apr 20 Python
Python可变参数函数用法实例
Jul 07 Python
Python实现的RSS阅读器实例
Jul 25 Python
python列表操作之extend和append的区别实例分析
Jul 28 Python
浅谈python中requests模块导入的问题
May 18 Python
利用python和ffmpeg 批量将其他图片转换为.yuv格式的方法
Jan 08 Python
Python 数据库操作 SQLAlchemy的示例代码
Feb 18 Python
python实现高斯判别分析算法的例子
Dec 09 Python
python 实现屏幕录制示例
Dec 23 Python
python 错误处理 assert详解
Apr 20 Python
pytho matplotlib工具栏源码探析一之禁用工具栏、默认工具栏和工具栏管理器三种模式的差异
Feb 25 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
PHP下常用正则表达式整理
2010/10/26 PHP
php中require和require_once的区别说明
2014/02/27 PHP
codeigniter中view通过循环显示数组数据的方法
2015/03/20 PHP
PHP抓取网页、解析HTML常用的方法总结
2015/07/01 PHP
PHP getName()函数讲解
2019/02/03 PHP
PHP实现二维数组按照指定的字段进行排序算法示例
2019/04/23 PHP
php+iframe 实现上传文件功能示例
2020/03/04 PHP
checkbox 多选框 联动实现代码
2008/10/22 Javascript
JQuery+JS实现仿百度搜索结果中关键字变色效果
2011/08/02 Javascript
50个比较实用jQuery代码段
2011/09/18 Javascript
javascript椭圆旋转相册实现代码
2012/01/16 Javascript
改进版通过Json对象实现深复制的方法
2012/10/24 Javascript
JS无限极树形菜单,json格式、数组格式通用示例
2013/07/30 Javascript
正则表达式(语法篇推荐)
2016/06/24 Javascript
JavaScript中的this使用详解
2016/07/27 Javascript
AngularJS过滤器filter用法分析
2016/12/11 Javascript
jQuery实用密码强度检测
2017/03/02 Javascript
vue实现百度搜索下拉提示功能实例
2017/06/14 Javascript
vue-router的HTML5 History 模式设置
2018/09/08 Javascript
Vue.js更改调试地址端口号的实例
2018/09/19 Javascript
javascript实现图片轮播代码
2019/07/09 Javascript
Vue中避免滥用this去读取data中数据
2021/03/02 Vue.js
[03:17]史诗级大片应援2018DOTA2国际邀请赛 致敬每一位坚守遗迹的勇士
2018/07/20 DOTA
python简单判断序列是否为空的方法
2015/06/30 Python
python 重定向获取真实url的方法
2018/05/11 Python
Python从单元素字典中获取key和value的实例
2018/12/31 Python
Django RBAC权限管理设计过程详解
2019/08/06 Python
python pprint模块中print()和pprint()两者的区别
2020/02/10 Python
英国著名药妆店:Superdrug
2021/02/13 全球购物
经理管理专业自荐信范文
2013/12/31 职场文书
贺卡寄语大全
2014/04/11 职场文书
幼儿教师师德演讲稿
2014/05/06 职场文书
三严三实学习心得体会
2014/10/13 职场文书
结婚堵门保证书
2015/05/08 职场文书
研究生学习计划书应该怎么写?
2019/09/10 职场文书
搭建zabbix监控以及邮件报警的超级详细教学
2022/07/15 Servers