使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django框架中render_to_response()函数的使用方法
Jul 16 Python
Python编程实战之Oracle数据库操作示例
Jun 21 Python
python下10个简单实例代码
Nov 15 Python
Python中super函数的用法
Nov 17 Python
Python如何获得百度统计API的数据并发送邮件示例代码
Jan 27 Python
Python从函数参数类型引出元组实例分析
May 28 Python
anaconda中更改python版本的方法步骤
Jul 14 Python
python中类的输出或类的实例输出为这种形式的原因
Aug 12 Python
PythonPC客户端自动化实现原理(pywinauto)
May 28 Python
Python爬虫使用bs4方法实现数据解析
Aug 25 Python
Python ellipsis 的用法详解
Nov 20 Python
如何用六步教会你使用python爬虫爬取数据
Apr 06 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
如何把PHP转成EXE文件
2006/10/09 PHP
基于PHP遍历数组的方法汇总分析
2013/06/08 PHP
解析将多维数组转换为支持curl提交的一维数组格式
2013/07/08 PHP
php中使用getimagesize获取图片、flash等文件的尺寸信息实例
2014/04/29 PHP
php使用pclzip类实现文件压缩的方法(附pclzip类下载地址)
2016/04/30 PHP
php实现图片以base64显示的方法
2016/10/13 PHP
PHPExcel中文帮助手册|PHPExcel使用方法(分享)
2017/06/09 PHP
学习YUI.Ext第五日--做拖放Darg&Drop
2007/03/10 Javascript
input的focus方法使用
2010/03/13 Javascript
jQuery+ajax实现顶一下,踩一下效果
2010/07/17 Javascript
js中用window.open()打开多个窗口的name问题
2014/03/13 Javascript
浅谈setTimeout 与 setInterval
2015/06/23 Javascript
JS验证IP,子网掩码,网关和MAC的方法
2015/07/02 Javascript
JavaScript实现页面跳转的方式汇总
2016/05/16 Javascript
JavaScript中return用法示例
2016/11/29 Javascript
BootStrap表单宽度设置方法
2017/03/10 Javascript
ES6中Symbol类型用法实例详解
2017/04/06 Javascript
完美解决浏览器跨域的几种方法(汇总)
2017/05/08 Javascript
基于layer.js实现收货地址弹框选择然后返回相应的地址信息
2017/05/26 Javascript
详解angular中的作用域及继承
2017/05/31 Javascript
vue-router+vuex addRoutes实现路由动态加载及菜单动态加载
2017/09/28 Javascript
Vue中对iframe实现keep alive无刷新的方法
2019/07/23 Javascript
vue下使用nginx刷新页面404的问题解决
2019/08/02 Javascript
Python 实现删除某路径下文件及文件夹的实例讲解
2018/04/24 Python
python对验证码降噪的实现示例代码
2019/11/12 Python
python中with语句结合上下文管理器操作详解
2019/12/19 Python
Python读取YAML文件过程详解
2019/12/30 Python
python异常处理、自定义异常、断言原理与用法分析
2020/03/23 Python
windows10 pycharm下安装pyltp库和加载模型实现语义角色标注的示例代码
2020/05/07 Python
为什么说python更适合树莓派编程
2020/07/20 Python
python logging模块的使用详解
2020/10/23 Python
css3中背景尺寸background-size详解
2014/09/02 HTML / CSS
在canvas上实现元素图片镜像翻转动画效果的方法
2018/03/20 HTML / CSS
日常奢侈品,轻松购物:Verishop
2019/08/20 全球购物
2015年酒店工作总结
2015/04/28 职场文书
准备去美国留学,那么大学申请文书应该怎么写?
2019/08/12 职场文书