获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pymssql ntext字段调用问题解决方法
Dec 17 Python
python基于xmlrpc实现二进制文件传输的方法
Jun 02 Python
python中如何使用正则表达式的非贪婪模式示例
Oct 09 Python
Python语言的变量认识及操作方法
Feb 11 Python
Python标准库笔记struct模块的使用
Feb 22 Python
Python3字符串encode与decode的讲解
Apr 02 Python
Flask框架模板继承实现方法分析
Jul 31 Python
关于django 1.10 CSRF验证失败的解决方法
Aug 31 Python
git查看、创建、删除、本地、远程分支方法详解
Feb 18 Python
基于PyTorch的permute和reshape/view的区别介绍
Jun 18 Python
Python request post上传文件常见要点
Nov 20 Python
pytorch 中autograd.grad()函数的用法说明
May 12 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
如何用phpmyadmin设置mysql数据库用户的权限
2012/01/09 PHP
PHP读取txt文本文件并分页显示的方法
2015/03/11 PHP
ThinkPHP实现的rsa非对称加密类示例
2018/05/29 PHP
PHP5.5新特性之yield理解与用法实例分析
2019/01/11 PHP
php字符串截取函数mb_substr用法实例分析
2019/06/25 PHP
深入分析PHP设计模式
2020/06/15 PHP
关于Mozilla浏览器不支持innerText的解决办法
2011/01/01 Javascript
javascript游戏开发之《三国志曹操传》零部件开发(三)情景对话中仿打字机输出文字
2013/01/23 Javascript
JS阻止冒泡事件以及默认事件发生的简单方法
2014/01/17 Javascript
JavaScript实现仿网易通行证表单验证
2015/05/25 Javascript
jquery+ajax+text文本框实现智能提示完整实例
2016/07/09 Javascript
JavaScript ES6中CLASS的使用详解
2016/11/22 Javascript
浅析Visual Studio Code断点调试Vue
2018/02/27 Javascript
vue中v-model的应用及使用详解
2018/06/27 Javascript
每周一练 之 数据结构与算法(Stack)
2019/04/16 Javascript
JavaScript 禁止用户保存图片的实现代码
2020/04/28 Javascript
[43:24]2018DOTA2亚洲邀请赛3月29日 小组赛A组 LGD VS Liquid
2018/03/30 DOTA
Python 解析XML文件
2009/04/15 Python
python数据结构树和二叉树简介
2014/04/29 Python
pymongo为mongodb数据库添加索引的方法
2015/05/11 Python
分享几道你可能遇到的python面试题
2017/07/24 Python
python 文件转成16进制数组的实例
2018/07/09 Python
Python实现字符串匹配的KMP算法
2019/04/04 Python
python+selenium定时爬取丁香园的新型冠状病毒数据并制作出类似的地图(部署到云服务器)
2020/02/09 Python
详解pyinstaller生成exe的闪退问题解决方案
2020/06/19 Python
Python解析微信dat文件的方法
2020/11/30 Python
详解Canvas 跨域脱坑实践
2018/11/07 HTML / CSS
俄罗斯运动鞋商店:Sneakerhead
2018/05/10 全球购物
表演方阵解说词
2014/02/08 职场文书
我爱我校演讲稿
2014/05/21 职场文书
校运动会广播稿(100篇)
2014/09/12 职场文书
反四风个人对照检查材料思想汇报
2014/09/25 职场文书
公司总经理岗位职责
2015/04/01 职场文书
2016党员党课心得体会
2016/01/07 职场文书
职场:企业印章管理制度(模板)
2019/10/18 职场文书
总结Python变量的相关知识
2021/06/28 Python