使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django中实现一个高性能计数器(Counter)实例
Jul 09 Python
简单的Python的curses库使用教程
Apr 11 Python
Python实现对象转换为xml的方法示例
Jun 08 Python
python中返回矩阵的行列方法
Apr 04 Python
从请求到响应过程中django都做了哪些处理
Aug 01 Python
解决在Python编辑器pycharm中程序run正常debug错误的问题
Jan 17 Python
Python importlib动态导入模块实现代码
Apr 16 Python
python实现最短路径的实例方法
Jul 19 Python
Python批量删除mysql中千万级大量数据的脚本分享
Dec 03 Python
Python try except else使用详解
Jan 12 Python
Python 把两层列表展开平铺成一层(5种实现方式)
Apr 07 Python
Python语法学习之进程的创建与常用方法详解
Apr 08 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
用PHP读取和编写XML DOM的实现代码
2011/02/03 PHP
laravel实现分页样式替换示例代码(增加首、尾页)
2017/09/22 PHP
JavaScript 应用类库代码
2008/06/02 Javascript
location.search在客户端获取Url参数的方法
2010/06/08 Javascript
理解Javascript_08_函数对象
2010/10/15 Javascript
情人节专属 纯js脚本1k大小的3D玫瑰效果
2012/02/11 Javascript
使用jQuery避免鼠标双击的解决方案
2013/08/21 Javascript
ComboBox 和 DateField 在IE下消失的解决方法
2013/08/30 Javascript
JavaScript fontsize方法入门实例(按照指定的尺寸来显示字符串)
2014/10/17 Javascript
JS实现网页表格自动变大缩小的方法
2015/03/09 Javascript
分享JavaScript与Java中MD5使用两个例子
2015/12/23 Javascript
textarea 在浏览器中固定大小和禁止拖动的实现方法
2016/12/03 Javascript
jQuery实现字符串全部替换的方法【推荐】
2017/03/09 Javascript
JS库之Waypoints的用法详解
2017/09/13 Javascript
Node.js使用cookie保持登录的方法
2018/05/11 Javascript
vue  自定义组件实现通讯录功能
2018/09/30 Javascript
js根据json数据中的某一个属性来给数据分组的方法
2018/10/08 Javascript
webpack4+react多页面架构的实现
2018/10/25 Javascript
对layui数据表格动态cols(字段)动态变化详解
2019/10/25 Javascript
JS实现时间校验的代码
2020/05/25 Javascript
在vant中使用时间选择器和popup弹出层的操作
2020/11/04 Javascript
[02:16]卖萌的僵尸 DOTA2神话信使飞僵小宝来袭
2014/03/24 DOTA
[14:21]VICI vs EG (BO3)
2018/06/07 DOTA
使用Python求解最大公约数的实现方法
2015/08/20 Python
基于pandas将类别属性转化为数值属性的方法
2018/07/25 Python
对PyQt5中树结构的实现方法详解
2019/06/17 Python
matlab中二维插值函数interp2的使用详解
2020/04/22 Python
python安装和pycharm环境搭建设置方法
2020/05/27 Python
python如何实现图片压缩
2020/09/11 Python
Happy Plugs官网:瑞典无线耳机品牌
2020/07/16 全球购物
生产车间班组长岗位职责
2014/01/06 职场文书
信息技术专业大学生职业生涯规划书
2014/01/24 职场文书
青年教师典范事迹材料
2014/01/31 职场文书
小区门卫岗位职责范本
2014/08/24 职场文书
Python-typing: 类型标注与支持 Any类型详解
2021/05/10 Python
springcloud之Feign超时问题的解决
2021/06/24 Java/Android