获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python操作SQLite简明教程
Jul 10 Python
2018年Python值得关注的开源库、工具和开发者(总结篇)
Jan 04 Python
浅谈Python黑帽子取代netcat
Feb 10 Python
Python+request+unittest实现接口测试框架集成实例
Mar 16 Python
numpy.where() 用法详解
May 27 Python
python全栈知识点总结
Jul 01 Python
python爬虫 基于requests模块的get请求实现详解
Aug 20 Python
Django密码存储策略分析
Jan 09 Python
如何基于python实现归一化处理
Jan 20 Python
python 如何利用argparse解析命令行参数
Sep 11 Python
python 实现数据库中数据添加、查询与更新的示例代码
Dec 07 Python
python 通过pip freeze、dowload打离线包及自动安装的过程详解(适用于保密的离线环境
Dec 14 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
经典的星际争霸,满是回忆的BGM
2020/04/09 星际争霸
WinXP + Apache +PHP5 + MySQL + phpMyAdmin安装全功略
2006/07/09 PHP
php中DOMDocument简单用法示例代码(XML创建、添加、删除、修改)
2010/12/19 PHP
PHP命名空间(namespace)的使用基础及示例
2014/08/18 PHP
PHP模板引擎Smarty内建函数foreach,foreachelse用法分析
2016/04/11 PHP
JQuery 网站换肤功能实现代码
2009/11/02 Javascript
javascript 面向对象继承
2009/11/26 Javascript
验证码按回车不变解决方法
2013/03/29 Javascript
Jqgrid表格随窗口大小改变而改变的简单实例
2013/12/28 Javascript
Javascript实现图片轮播效果(一)让图片跳动起来
2016/02/17 Javascript
微信小程序 http请求详细介绍
2016/10/09 Javascript
微信小程序 loading(加载中提示框)实例
2016/10/28 Javascript
Vue.js进行查询操作的实例详解
2017/08/25 Javascript
详解基于Angular4+ server render(服务端渲染)开发教程
2017/08/28 Javascript
layui table 参数设置方法
2018/08/14 Javascript
小程序实现多选框功能
2018/10/30 Javascript
[03:59]DOTA2英雄梦之声_第07期_水晶室女
2014/06/23 DOTA
解决Python requests 报错方法集锦
2017/03/19 Python
Python实现手写一个类似django的web框架示例
2018/07/20 Python
浅谈pytorch和Numpy的区别以及相互转换方法
2018/07/26 Python
selenium+python设置爬虫代理IP的方法
2018/11/29 Python
selenium+python自动化测试之环境搭建
2019/01/23 Python
python飞机大战 pygame游戏创建快速入门详解
2019/12/17 Python
python scatter函数用法实例详解
2020/02/11 Python
jupyter notebook 参数传递给shell命令行实例
2020/04/10 Python
pytorch中的weight-initilzation用法
2020/06/24 Python
Django用内置方法实现简单搜索功能的方法
2020/12/18 Python
python画图时设置分辨率和画布大小的实现(plt.figure())
2021/01/08 Python
基于HTML5 Canvas的3D动态Chart图表的示例
2017/11/02 HTML / CSS
Ralph Lauren法国官网:美国高品味时装品牌
2017/12/08 全球购物
会计专业自荐信范文
2013/12/02 职场文书
蜜蜂引路教学反思
2014/02/04 职场文书
《白鹅》教学反思
2014/04/13 职场文书
一个都不能少观后感
2015/06/04 职场文书
python实战之用emoji表情生成文字
2021/05/08 Python
JUnit5常用注解的使用
2021/07/02 Java/Android