pytorch 可视化feature map的示例代码


Posted in Python onAugust 20, 2019

之前做的一些项目中涉及到feature map 可视化的问题,一个层中feature map的数量往往就是当前层out_channels的值,我们可以通过以下代码可视化自己网络中某层的feature map,个人感觉可视化feature map对调参还是很有用的。

不多说了,直接看代码:

import torch
from torch.autograd import Variable
import torch.nn as nn
import pickle

from sys import path
path.append('/residual model path')
import residual_model
from residual_model import Residual_Model

model = Residual_Model()
model.load_state_dict(torch.load('./model.pkl'))



class myNet(nn.Module):
  def __init__(self,pretrained_model,layers):
    super(myNet,self).__init__()
    self.net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]])
    self.net2 = nn.Sequential(*list(pretrained_model.children())[:layers[1]])
    self.net3 = nn.Sequential(*list(pretrained_model.children())[:layers[2]])

  def forward(self,x):
    out1 = self.net1(x)
    out2 = self.net(out1)
    out3 = self.net(out2)
    return out1,out2,out3

def get_features(pretrained_model, x, layers = [3, 4, 9]): ## get_features 其实很简单
'''
1.首先import model 
2.将weights load 进model
3.熟悉model的每一层的位置,提前知道要输出feature map的网络层是处于网络的那一层
4.直接将test_x输入网络,*list(model.chidren())是用来提取网络的每一层的结构的。net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]]) ,就是第三层前的所有层。

'''
  net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]]) 
#  print net1 
  out1 = net1(x) 

  net2 = nn.Sequential(*list(pretrained_model.children())[layers[0]:layers[1]]) 
#  print net2 
  out2 = net2(out1) 

  #net3 = nn.Sequential(*list(pretrained_model.children())[layers[1]:layers[2]]) 
  #out3 = net3(out2) 

  return out1, out2
with open('test.pickle','rb') as f:
  data = pickle.load(f)
x = data['test_mains'][0]
x = Variable(torch.from_numpy(x)).view(1,1,128,1) ## test_x必须为Varibable
#x = Variable(torch.randn(1,1,128,1))
if torch.cuda.is_available():
  x = x.cuda() # 如果模型的训练是用cuda加速的话,输入的变量也必须是cuda加速的,两个必须是对应的,网络的参数weight都是用cuda加速的,不然会报错
  model = model.cuda()
output1,output2 = get_features(model,x)## model是训练好的model,前面已经import 进来了Residual model
print('output1.shape:',output1.shape)
print('output2.shape:',output2.shape)
#print('output3.shape:',output3.shape)
output_1 = torch.squeeze(output2,dim = 0)
output_1_arr = output_1.data.cpu().numpy() # 得到的cuda加速的输出不能直接转变成numpy格式的,当时根据报错的信息首先将变量转换为cpu的,然后转换为numpy的格式
output_1_arr = output_1_arr.reshape([output_1_arr.shape[0],output_1_arr.shape[1]])

以上这篇pytorch 可视化feature map的示例代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入解析Python中的__builtins__内建对象
Jun 21 Python
Python中int()函数的用法浅析
Oct 17 Python
50行Python代码实现人脸检测功能
Jan 23 Python
Python装饰器用法示例小结
Feb 11 Python
详解如何在python中读写和存储matlab的数据文件(*.mat)
Feb 24 Python
浅谈python numpy中nonzero()的用法
Apr 02 Python
python实现n个数中选出m个数的方法
Nov 13 Python
django做form表单的数据验证过程详解
Jul 26 Python
Ubuntu下Python+Flask分分钟搭建自己的服务器教程
Nov 19 Python
python:目标检测模型预测准确度计算方式(基于IoU)
Jan 18 Python
基于python的matplotlib制作双Y轴图
Apr 20 Python
Python爬虫入门案例之回车桌面壁纸网美女图片采集
Oct 16 Python
python爬虫 基于requests模块的get请求实现详解
Aug 20 #Python
python爬虫 urllib模块url编码处理详解
Aug 20 #Python
pytorch实现用Resnet提取特征并保存为txt文件的方法
Aug 20 #Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
You might like
剖析 PHP 中的输出缓冲
2006/12/21 PHP
thinkphp模板的包含与渲染实例分析
2014/11/26 PHP
功能强大的php分页函数
2016/07/20 PHP
laravel实现按时间日期进行分组统计方法示例
2019/03/23 PHP
由JavaScript技术实现的web小游戏(不含网游)
2010/06/12 Javascript
Node.js 应用跑得更快 10 个技巧
2016/04/03 Javascript
javascript中数组和字符串的方法对比
2016/07/20 Javascript
jquery判断对象是否为空并遍历对象的简单实例
2016/07/26 Javascript
JS+CSS3模拟溢出滚动效果
2016/08/12 Javascript
node.js与C语言 实现遍历文件夹下最大的文件,并输出路径,大小
2017/01/20 Javascript
jQuery实现简单的滑动导航代码(移动端)
2017/05/22 jQuery
Angular.js自动化测试之protractor详解
2017/07/07 Javascript
详解vue-cli3使用
2018/08/14 Javascript
vue踩坑记-在项目中安装依赖模块npm install报错
2019/04/02 Javascript
Vue 3自定义指令开发的相关总结
2021/01/29 Vue.js
[34:10]Secret vs VG 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.24
2019/09/10 DOTA
python使用paramiko模块实现ssh远程登陆上传文件并执行
2014/01/27 Python
python定时器使用示例分享
2014/02/16 Python
深入理解python中的浅拷贝和深拷贝
2016/05/30 Python
Python爬取当当、京东、亚马逊图书信息代码实例
2017/12/09 Python
Python numpy.array()生成相同元素数组的示例
2018/11/12 Python
Python 支付整合开发包的实现
2019/01/23 Python
pytorch神经网络之卷积层与全连接层参数的设置方法
2019/08/18 Python
详解Django配置优化方法
2019/11/18 Python
使用Python 自动生成 Word 文档的教程
2020/02/13 Python
css3 transform过渡抖动问题解决
2020/10/23 HTML / CSS
Html5原生拖拽相关事件简介以及基础实现
2020/11/19 HTML / CSS
微软美国官方网站:Microsoft美国
2018/05/10 全球购物
一年级家长会邀请函
2014/01/25 职场文书
课堂教学改革实施方案
2014/03/17 职场文书
北京离婚协议书范文2014
2014/09/29 职场文书
社团个人总结范文
2015/03/05 职场文书
赞助商致辞
2015/07/30 职场文书
2016大一新生军训心得体会
2016/01/11 职场文书
MySQL远程无法连接的一些常见原因总结
2022/09/23 MySQL
CSS使用SVG实现动态分布的圆环发散路径动画
2022/12/24 HTML / CSS