获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
介绍Python的@property装饰器的用法
Apr 28 Python
Python、PyCharm安装及使用方法(Mac版)详解
Apr 28 Python
python将文本分每两行一组并保存到文件
Mar 19 Python
python监控进程脚本
Apr 12 Python
Python获取昨天、今天、明天开始、结束时间戳的方法
Jun 01 Python
用Django写天气预报查询网站
Oct 21 Python
python实现一个简单的udp通信的示例代码
Feb 01 Python
python conda操作方法
Sep 11 Python
Pytorch 搭建分类回归神经网络并用GPU进行加速的例子
Jan 09 Python
Python3爬虫发送请求的知识点实例
Jul 30 Python
详解python算法常用技巧与内置库
Oct 17 Python
Pytorch实验常用代码段汇总
Nov 19 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
十大催泪虐心动漫电影,有几部你还没看
2020/03/04 日漫
解析posix与perl标准的正则表达式区别
2013/06/17 PHP
PHP实现支持GET,POST,Multipart/form-data的HTTP请求类
2014/09/24 PHP
PHP使用Pear发送邮件(Windows环境)
2016/01/05 PHP
php求数组全排列,元素所有组合的方法总结
2017/03/14 PHP
Ubuntu 16.04下安装PHP 7过程详解
2017/03/28 PHP
基于jquery的has()方法以及与find()方法以及filter()方法的区别详解
2013/04/26 Javascript
运用jQuery定时器的原理实现banner图片切换
2014/10/22 Javascript
jquery trigger实现联动的方法
2016/02/29 Javascript
实例讲解Jquery中隐藏hide、显示show、切换toggle的用法
2016/05/13 Javascript
jQuery实现的图片轮播效果完整示例
2016/09/12 Javascript
JS判断Android、iOS或浏览器的多种方法(四种方法)
2017/06/29 Javascript
让你彻底掌握es6 Promise的八段代码
2017/07/26 Javascript
引入JavaScript时alert弹出框显示中文乱码问题
2017/09/16 Javascript
使用JS获取SessionStorage的值
2018/01/12 Javascript
node.js实现为PDF添加水印的示例代码
2018/12/05 Javascript
了解JavaScript表单操作和表单域
2019/05/27 Javascript
ElementUI中el-tree节点的操作的实现
2020/02/27 Javascript
python解决方案:WindowsError: [Error 2]
2016/08/28 Python
Python竟能画这么漂亮的花,帅呆了(代码分享)
2017/11/15 Python
Python复制Word内容并使用格式设字体与大小实例代码
2018/01/22 Python
python 制作自定义包并安装到系统目录的方法
2018/10/27 Python
如何安装并使用conda指令管理python环境
2019/07/10 Python
解决Python二维数组赋值问题
2019/11/28 Python
如何利用CSS3制作3D效果文字具体实现样式
2013/05/02 HTML / CSS
美国鞋类购物网站:Shiekh Shoes
2016/08/21 全球购物
JINS眼镜官方网站:日本最大的眼镜邮购
2016/10/14 全球购物
村居抓节水倡议书
2014/05/19 职场文书
摄影专业毕业生求职信
2014/08/05 职场文书
购房意向书
2014/08/30 职场文书
学习党代会心得体会
2014/09/05 职场文书
党的群众路线教育实践活动领导班子对照检查材料
2014/09/25 职场文书
经验交流材料格式
2014/12/30 职场文书
劳保用品管理制度范本
2015/08/06 职场文书
温馨祝福晨语:美丽的一天从我的问候开始
2019/11/28 职场文书
docker compose 部署 golang 的 Athens 私有代理问题
2022/04/28 Servers