获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python去掉字符串中空格的方法
Mar 11 Python
Python中的深拷贝和浅拷贝详解
Jun 03 Python
python中获得当前目录和上级目录的实现方法
Oct 12 Python
100行Python代码实现自动抢火车票(附源码)
Jan 11 Python
python实现linux下抓包并存库功能
Jul 18 Python
Python中的random.uniform()函数教程与实例解析
Mar 02 Python
Django框架中间件(Middleware)用法实例分析
May 24 Python
python内打印变量之%和f的实例
Feb 19 Python
python数据库操作mysql:pymysql、sqlalchemy常见用法详解
Mar 30 Python
Python中Selenium模块的使用详解
Oct 09 Python
python爬虫框架feapde的使用简介
Apr 20 Python
Python入门之使用pandas分析excel数据
May 12 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
使用PHP实现蜘蛛访问日志统计
2013/07/05 PHP
PHP模板引擎Smarty内建函数详解
2016/04/11 PHP
php验证身份证号码正确性的函数
2016/07/20 PHP
mysql desc(DESCRIBE)命令实例讲解
2016/09/24 PHP
PHP插件PHPMailer发送邮件功能
2017/02/28 PHP
document对象execCommand的command参数介绍
2006/08/01 Javascript
基于jquery的图片的切换(以数字的形式)
2011/02/14 Javascript
JavaScript使用IEEE 标准进行二进制浮点运算产生莫名错误的解决方法
2011/05/28 Javascript
JQuery 操作/获取table具体代码
2013/06/13 Javascript
jquery中load方法的用法及注意事项说明
2014/02/22 Javascript
jquery动态添加删除一行数据示例
2014/06/12 Javascript
JS实现倒计时和文字滚动的效果实例
2014/10/29 Javascript
让html页面不缓存js的实现方法
2014/10/31 Javascript
jquery checkbox 勾选的bug问题解决方案与分析
2014/11/13 Javascript
JS实现输入框提示文字点击时消失效果
2016/07/19 Javascript
探索Javascript中this的奥秘
2016/12/11 Javascript
完美解决JS文件页面加载时的阻塞问题
2016/12/18 Javascript
基于vue实现分页/翻页组件paginator示例
2017/03/09 Javascript
JavaScript 中 apply 、call 的详解
2017/03/21 Javascript
weebox弹出窗口不居中显示的解决方法
2017/11/27 Javascript
Angular2+如何去除url中的#号详解
2017/12/20 Javascript
Angular4 ElementRef的应用
2018/02/26 Javascript
js中数组对象去重的两种方法
2019/01/18 Javascript
微信小程序 行的删除和增加操作实现详解
2019/09/29 Javascript
JS使用H5实现图片预览功能
2019/09/30 Javascript
如何使用JavaScript检测空闲的浏览器选项卡
2020/05/28 Javascript
[02:27]2018DOTA2亚洲邀请赛赛前采访-OpTic
2018/04/03 DOTA
解决python3中自定义wsgi函数,make_server函数报错的问题
2017/11/21 Python
tensorflow识别自己手写数字
2018/03/14 Python
Python 实现微信防撤回功能
2019/04/29 Python
纯css3实现宠物小鸡实例代码
2018/10/08 HTML / CSS
机电工程专业应届生求职信
2013/10/03 职场文书
先进个人评语大全
2015/01/04 职场文书
2015年企业工作总结范文
2015/04/28 职场文书
2015年计划生育协会工作总结
2015/05/13 职场文书
2016幼儿园教师年度考核评语
2015/12/01 职场文书