获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Pyramid将models.py文件的内容分布到多个文件的方法
Nov 27 Python
Python中使用PyHook监听鼠标和键盘事件实例
Jul 18 Python
Python模仿POST提交HTTP数据及使用Cookie值的方法
Nov 10 Python
简单掌握Python的Collections模块中counter结构的用法
Jul 07 Python
Linux RedHat下安装Python2.7开发环境
May 20 Python
python简单实现AES加密和解密
Mar 28 Python
python字符串切割:str.split()与re.split()的对比分析
Jul 16 Python
利用python实现短信和电话提醒功能的例子
Aug 08 Python
django模板获取list中指定索引的值方式
May 14 Python
使用pytorch实现论文中的unet网络
Jun 24 Python
安装并免费使用Pycharm专业版(学生/教师)
Sep 24 Python
python获取字符串中的email
Mar 31 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
PHP 截取字符串函数整理(支持gb2312和utf-8)
2010/02/16 PHP
求PHP数组最大值,最小值的代码
2011/10/31 PHP
php eval函数用法 PHP中eval()函数小技巧
2012/10/31 PHP
php删除与复制文件夹及其文件夹下所有文件的实现代码
2013/01/23 PHP
ThinkPHP之M方法实例详解
2014/06/20 PHP
PHP数组去重比较快的实现方式
2016/01/19 PHP
php简单计算年龄的方法(周岁与虚岁)
2016/12/06 PHP
微信推送功能实现方式图文详解
2019/07/12 PHP
jQuery 1.3 和 Validation 验证插件1.5.1
2009/07/09 Javascript
Javascript表达式中连续的 && 和 || 之赋值区别
2010/10/17 Javascript
使用Jquery搭建最佳用户体验的登录页面之记住密码自动登录功能(含后台代码)
2011/07/10 Javascript
JS常用正则表达式总结
2013/11/12 Javascript
JavaScript中指定函数名称的相关方法
2015/06/04 Javascript
Javascript函数的参数
2015/07/16 Javascript
javascript绘制漂亮的心型线效果完整实例
2016/02/02 Javascript
jquery实现数字输入框
2017/02/22 Javascript
JS/jquery实现一个网页内同时调用多个倒计时的方法
2017/04/27 jQuery
nodejs实现截取上传视频中一帧作为预览图片
2017/12/10 NodeJs
详解小程序退出页面时清除定时器
2019/04/28 Javascript
用Python编写web API的教程
2015/04/30 Python
Python+PIL实现支付宝AR红包
2018/02/09 Python
在双python下设置python3为默认的方法
2018/10/31 Python
Python批量修改图片分辨率的实例代码
2019/07/04 Python
pytorch: Parameter 的数据结构实例
2019/12/31 Python
python 在sql语句中使用%s,%d,%f说明
2020/06/06 Python
Python 常用日期处理 -- calendar 与 dateutil 模块的使用
2020/09/02 Python
纯css3实现的竖形无限级导航
2014/12/10 HTML / CSS
Casetify官网:自制专属手机壳、iPad护壳和Apple Watch手表带
2018/05/09 全球购物
美体小铺印度官网:The Body Shop印度
2019/10/17 全球购物
客服服务心得体会
2013/12/30 职场文书
社区巾帼文明岗事迹材料
2014/06/03 职场文书
启动仪式策划方案
2014/06/14 职场文书
全国优秀教师事迹材料
2014/08/26 职场文书
办公楼租房协议书范本
2014/11/25 职场文书
2015年政务公开工作总结
2015/05/19 职场文书
2016年寒假学习心得体会
2015/10/09 职场文书