获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python命令启动Web服务器实例详解
Feb 23 Python
Python中正则表达式详解
May 17 Python
Python虚拟环境项目实例
Nov 20 Python
Python模拟随机游走图形效果示例
Feb 06 Python
opencv python 图像轮廓/检测轮廓/绘制轮廓的方法
Jul 03 Python
Django文件存储 默认存储系统解析
Aug 02 Python
python实现多进程通信实例分析
Sep 01 Python
python装饰器的特性原理详解
Dec 25 Python
TensorFlow tensor的拼接实例
Jan 19 Python
使用Python三角函数公式计算三角形的夹角案例
Apr 15 Python
python 浮点数四舍五入需要注意的地方
Aug 18 Python
Python 生成短8位唯一id实战教程
Jan 13 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
ajax实现无刷新分页(php)
2010/07/18 PHP
异步加载技术实现当滚动条到最底部的瀑布流效果
2014/09/16 PHP
PHP中调用SVN命令更新网站方法
2015/01/07 PHP
PHP之预定义接口详解
2015/07/29 PHP
php通过curl添加cookie伪造登陆抓取数据的方法
2016/04/02 PHP
laravel 框架实现无限级分类的方法示例
2019/10/31 PHP
js数字输入框(包括最大值最小值限制和四舍五入)
2009/11/24 Javascript
js 面向对象的技术创建高级 Web 应用程序
2010/02/25 Javascript
打开新窗口关闭当前页面不弹出关闭提示js代码
2013/03/18 Javascript
JS中引用百度地图并将百度地图的logo和信息去掉
2013/09/29 Javascript
JS判断移动端访问设备并加载对应CSS样式
2014/06/13 Javascript
js中iframe调用父页面的方法
2014/10/30 Javascript
Javascript实现的简单右键菜单类
2015/09/23 Javascript
js实现html table 行,列锁定的简单实例
2016/10/13 Javascript
Nodejs进阶:基于express+multer的文件上传实例
2016/11/21 NodeJs
JavaScript实现的select点菜功能示例
2017/01/16 Javascript
10个经典的网页鼠标特效代码
2018/01/09 Javascript
对node.js中render和send的用法详解
2018/05/14 Javascript
js实现类似iphone的网页滑屏解锁功能示例【附源码下载】
2019/06/10 Javascript
jQuery实现图片随机切换、抽奖功能(实例代码)
2019/10/23 jQuery
Webpack5正式发布,有哪些新特性
2020/10/12 Javascript
vue项目中js-cookie的使用存储token操作
2020/11/13 Javascript
用Python的SimPy库简化复杂的编程模型的介绍
2015/04/13 Python
Python实现求两个csv文件交集的方法
2017/09/06 Python
基于Python __dict__与dir()的区别详解
2017/10/30 Python
对python中for、if、while的区别与比较方法
2018/06/25 Python
python3使用pandas获取股票数据的方法
2018/12/22 Python
pytorch加载语音类自定义数据集的方法教程
2020/11/10 Python
前端实现打印图像功能
2019/08/27 HTML / CSS
药学专业大学生个人的自我评价
2013/11/04 职场文书
法警的竞聘演讲稿
2014/01/02 职场文书
财务管理专业求职信
2014/06/11 职场文书
汽车维修求职信
2014/06/15 职场文书
早读课迟到检讨书
2014/09/25 职场文书
办公室主任个人对照检查材料思想汇报
2014/10/11 职场文书
初中班级口号霸气押韵
2015/12/24 职场文书