使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常用的文件及文件路径、目录操作方法汇总介绍
May 21 Python
Python3实现从指定路径查找文件的方法
May 22 Python
决策树剪枝算法的python实现方法详解
Sep 18 Python
深入浅析Python科学计算库Scipy及安装步骤
Oct 12 Python
python使用yield压平嵌套字典的超简单方法
Nov 02 Python
关于pandas的离散化,面元划分详解
Nov 22 Python
使用Django实现把两个模型类的数据聚合在一起
Mar 28 Python
python安装后的目录在哪里
Jun 21 Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 Python
五分钟带你搞懂python 迭代器与生成器
Aug 30 Python
教你使用pyinstaller打包Python教程
May 27 Python
Python  lambda匿名函数和三元运算符
Apr 19 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
现磨咖啡骗局!现磨咖啡=新鲜咖啡?现磨咖啡背后的猫腻你不懂!
2019/03/28 冲泡冲煮
PHP使用PDO连接ACCESS数据库
2015/03/05 PHP
php数组索引与键值操作技巧实例分析
2015/06/24 PHP
laravel学习笔记之模型事件的几种用法示例
2017/08/15 PHP
jQuery 使用手册(六)
2009/09/23 Javascript
关于jQuery中的end()使用方法
2011/07/10 Javascript
JavaScript 参数中的数组展开 [译]
2012/09/21 Javascript
js简单实现根据身份证号码识别性别年龄生日
2013/11/29 Javascript
扩展IE中一些不兼容的方法如contains、startWith等等
2014/01/09 Javascript
jQuery响应鼠标事件并隐藏与显示input默认值
2014/08/24 Javascript
angularJS 中$scope方法使用指南
2015/02/09 Javascript
获取select的value、text值的简单示例(jquery与javascript)
2016/12/07 Javascript
js实现开启密码大写提示
2016/12/21 Javascript
Jquery鼠标放上去显示全名的实现方法
2017/02/06 Javascript
json前后端数据交互相关代码
2018/09/19 Javascript
对angularJs中$sce服务安全显示html文本的实例
2018/09/30 Javascript
js核心基础之闭包的应用实例分析
2019/05/11 Javascript
vue实现在线学生录入系统
2020/05/30 Javascript
python实现将汉字转换成汉语拼音的库
2015/05/05 Python
python自定义类并使用的方法
2015/05/07 Python
回调函数的意义以及python实现实例
2017/06/20 Python
Python守护线程用法实例
2017/06/23 Python
Python GUI Tkinter简单实现个性签名设计
2018/06/19 Python
Python动态导入模块的方法实例分析
2018/06/28 Python
python+opencv实现车牌定位功能(实例代码)
2019/12/24 Python
浅谈在JupyterNotebook下导入自己的模块的问题
2020/04/16 Python
python 使用递归的方式实现语义图片分割功能
2020/07/16 Python
CSS3 清除浮动的方法示例
2018/06/01 HTML / CSS
草莓网化妆品加拿大网站:Strawberrynet Canada
2016/09/20 全球购物
华为俄罗斯官方网上商城:购买Huawei手机和平板
2017/04/21 全球购物
巴西服装和鞋子购物网站:Marisa
2018/10/25 全球购物
加大码胸罩、内裤和服装:Just My Size
2019/03/21 全球购物
大学生两会精神学习心得体会
2014/03/10 职场文书
团日活动总结格式
2015/05/11 职场文书
法律服务所工作总结
2015/08/10 职场文书
MongoDB使用场景总结
2022/02/24 MongoDB