pytorch 可视化feature map的示例代码


Posted in Python onAugust 20, 2019

之前做的一些项目中涉及到feature map 可视化的问题,一个层中feature map的数量往往就是当前层out_channels的值,我们可以通过以下代码可视化自己网络中某层的feature map,个人感觉可视化feature map对调参还是很有用的。

不多说了,直接看代码:

import torch
from torch.autograd import Variable
import torch.nn as nn
import pickle

from sys import path
path.append('/residual model path')
import residual_model
from residual_model import Residual_Model

model = Residual_Model()
model.load_state_dict(torch.load('./model.pkl'))



class myNet(nn.Module):
  def __init__(self,pretrained_model,layers):
    super(myNet,self).__init__()
    self.net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]])
    self.net2 = nn.Sequential(*list(pretrained_model.children())[:layers[1]])
    self.net3 = nn.Sequential(*list(pretrained_model.children())[:layers[2]])

  def forward(self,x):
    out1 = self.net1(x)
    out2 = self.net(out1)
    out3 = self.net(out2)
    return out1,out2,out3

def get_features(pretrained_model, x, layers = [3, 4, 9]): ## get_features 其实很简单
'''
1.首先import model 
2.将weights load 进model
3.熟悉model的每一层的位置,提前知道要输出feature map的网络层是处于网络的那一层
4.直接将test_x输入网络,*list(model.chidren())是用来提取网络的每一层的结构的。net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]]) ,就是第三层前的所有层。

'''
  net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]]) 
#  print net1 
  out1 = net1(x) 

  net2 = nn.Sequential(*list(pretrained_model.children())[layers[0]:layers[1]]) 
#  print net2 
  out2 = net2(out1) 

  #net3 = nn.Sequential(*list(pretrained_model.children())[layers[1]:layers[2]]) 
  #out3 = net3(out2) 

  return out1, out2
with open('test.pickle','rb') as f:
  data = pickle.load(f)
x = data['test_mains'][0]
x = Variable(torch.from_numpy(x)).view(1,1,128,1) ## test_x必须为Varibable
#x = Variable(torch.randn(1,1,128,1))
if torch.cuda.is_available():
  x = x.cuda() # 如果模型的训练是用cuda加速的话,输入的变量也必须是cuda加速的,两个必须是对应的,网络的参数weight都是用cuda加速的,不然会报错
  model = model.cuda()
output1,output2 = get_features(model,x)## model是训练好的model,前面已经import 进来了Residual model
print('output1.shape:',output1.shape)
print('output2.shape:',output2.shape)
#print('output3.shape:',output3.shape)
output_1 = torch.squeeze(output2,dim = 0)
output_1_arr = output_1.data.cpu().numpy() # 得到的cuda加速的输出不能直接转变成numpy格式的,当时根据报错的信息首先将变量转换为cpu的,然后转换为numpy的格式
output_1_arr = output_1_arr.reshape([output_1_arr.shape[0],output_1_arr.shape[1]])

以上这篇pytorch 可视化feature map的示例代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python将阿拉伯数字转换为罗马数字的方法
Jul 10 Python
Python解析树及树的遍历
Feb 03 Python
python爬虫爬取淘宝商品信息(selenum+phontomjs)
Feb 24 Python
python将文本中的空格替换为换行的方法
Mar 19 Python
Python切片工具pillow用法示例
Mar 30 Python
python 读取txt,json和hdf5文件的实例
Jun 05 Python
Python生成MD5值的两种方法实例分析
Apr 26 Python
ubuntu上安装python的实例方法
Sep 30 Python
Python爬虫实现百度翻译功能过程详解
May 29 Python
使用Tensorflow-GPU禁用GPU设置(CPU与GPU速度对比)
Jun 30 Python
opencv+pyQt5实现图片阈值编辑器/寻色块阈值利器
Nov 13 Python
浅谈Python数学建模之整数规划
Jun 23 Python
python爬虫 基于requests模块的get请求实现详解
Aug 20 #Python
python爬虫 urllib模块url编码处理详解
Aug 20 #Python
pytorch实现用Resnet提取特征并保存为txt文件的方法
Aug 20 #Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
You might like
解析PHP SPL标准库的用法(遍历目录,查找固定条件的文件)
2013/06/18 PHP
php实例分享之二维数组排序
2014/05/15 PHP
ThinkPHP、ZF2、Yaf、Laravel框架路由大比拼
2015/03/25 PHP
如何使用Gitblog和Markdown建自己的博客
2015/07/31 PHP
快速解决PHP调用Word组件DCOM权限的问题
2017/12/27 PHP
IE8 下的Js错误HTML Parsing Error...
2009/08/14 Javascript
JavaScript类和继承 constructor属性
2010/03/04 Javascript
javascript下对于事件、事件流、事件触发的顺序随便说说
2010/07/17 Javascript
jquery中$(#form :input)与$(#form input)的区别
2014/08/18 Javascript
12306验证码破解思路分享
2015/03/25 Javascript
jquery插件ajaxupload实现文件上传操作
2015/12/09 Javascript
jQuery简单操作cookie的插件实例
2016/01/13 Javascript
jQuery+css3实现转动的正方形效果(附demo源码下载)
2016/01/27 Javascript
AngularJS整合Springmvc、Spring、Mybatis搭建开发环境
2016/02/25 Javascript
Javascript中获取浏览器类型和操作系统版本等客户端信息常用代码
2016/06/28 Javascript
jquery判断类型是不是number类型的实例代码
2016/10/07 Javascript
Bootstrap表单控件学习使用
2017/03/07 Javascript
Angular2中select用法之设置默认值与事件详解
2017/05/07 Javascript
微信小程序“摇一摇”的实例代码
2017/07/20 Javascript
Element input树型下拉框的实现代码
2018/12/21 Javascript
在vue中实现某一些路由页面隐藏导航栏的功能操作
2020/09/21 Javascript
小程序组件传值和引入sass的方法(使用vant Weapp组件库)
2020/11/24 Javascript
python cv2读取rtsp实时码流按时生成连续视频文件方式
2019/12/25 Python
使用spring mvc+localResizeIMG实现HTML5端图片压缩上传的功能
2016/12/16 HTML / CSS
美国畅销的跑步机品牌:ProForm
2017/02/06 全球购物
草莓网化妆品澳大利亚站:Strawberrynet AU
2017/12/18 全球购物
德国领先的大尺码和超大尺码男装在线零售商:Bigtex
2019/06/22 全球购物
Smilodox官方运动服装店:从运动服到健身配件
2020/08/27 全球购物
乡镇总工会学雷锋活动总结
2014/03/01 职场文书
党员先锋岗事迹材料
2014/05/08 职场文书
2014年秋季开学演讲稿
2014/05/24 职场文书
领导班子三严三实对照检查材料
2014/09/25 职场文书
绍兴鲁迅故居导游词
2015/02/09 职场文书
MySQL query_cache_type 参数与使用详解
2021/07/01 MySQL
spring cloud 配置中心客户端启动遇到的问题
2021/09/25 Java/Android
Node.js实现爬取网站图片的示例代码
2022/04/04 NodeJs