使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现随机密码字典生成器示例
Apr 09 Python
python实现微信远程控制电脑
Feb 22 Python
python获取网页中所有图片并筛选指定分辨率的方法
Mar 31 Python
Python字符串的一些操作方法总结
Jun 10 Python
pycharm 批量修改变量名称的方法
Aug 01 Python
你还在@微信官方?聊聊Python生成你想要的微信头像
Sep 25 Python
python实现简易淘宝购物
Nov 22 Python
python创建ArcGIS shape文件的实现
Dec 06 Python
python求前n个阶乘的和实例
Apr 02 Python
基于Python中random.sample()的替代方案
May 23 Python
Python实现8种常用抽样方法
Jun 27 Python
请求模块urllib之PYTHON爬虫的基本使用
Apr 08 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
PHP脚本的10个技巧(3)
2006/10/09 PHP
php数据结构 算法(PHP描述) 简单选择排序 simple selection sort
2011/08/09 PHP
PHP 图片上传代码
2011/09/13 PHP
php实现读取和写入tab分割的文件
2015/06/01 PHP
CodeIgniter与PHP5.6的兼容问题
2015/07/16 PHP
asp.net和asp下ACCESS的参数化查询
2008/06/11 Javascript
JAVASCRIPT style 中visibility和display之间的区别
2010/01/22 Javascript
jQuery中 noConflict() 方法使用
2013/04/25 Javascript
JS定时器使用,定时定点,固定时刻,循环执行详解
2016/05/31 Javascript
js数组的五种迭代方法及两种归并方法(推荐)
2016/06/14 Javascript
javaScript语法总结
2016/11/25 Javascript
JavaScript自定义浏览器滚动条兼容IE、 火狐和chrome
2017/01/05 Javascript
jquery操作select取值赋值与设置选中实例
2017/02/28 Javascript
jquery 一键复制到剪切板的实例
2017/09/20 jQuery
详解JavaScript中操作符和表达式
2018/09/12 Javascript
在Web关闭页面时发送Ajax请求的实现方法
2019/03/07 Javascript
js prototype和__proto__的关系是什么
2019/08/23 Javascript
[01:16:12]完美世界DOTA2联赛PWL S2 FTD vs Inki 第一场 11.21
2020/11/23 DOTA
Python常用算法学习基础教程
2017/04/13 Python
Python使用getpass库读取密码的示例
2017/10/10 Python
python编写微信远程控制电脑的程序
2018/01/05 Python
selenium+python实现自动登录脚本
2018/04/22 Python
Python3匿名函数lambda介绍与使用示例
2019/05/18 Python
python进度条显示之tqmd模块
2020/08/22 Python
详解python os.path.exists判断文件或文件夹是否存在
2020/11/16 Python
python解压zip包中文乱码解决方法
2020/11/27 Python
介绍一下Prototype的$()函数,$F()函数,$A()函数都是什么作用?
2014/03/05 面试题
社区志愿者心得体会
2014/01/03 职场文书
副总经理党的群众路线教育实践活动个人对照检查材料思想汇报
2014/10/06 职场文书
2015年银行客户经理工作总结
2015/04/01 职场文书
母亲去世追悼词
2015/06/23 职场文书
《圆的周长》教学反思
2016/02/17 职场文书
2016年少先队活动总结
2016/04/06 职场文书
古诗之感恩老师
2019/10/24 职场文书
利用 SQL Server 过滤索引提高查询语句的性能分析
2021/07/15 SQL Server
JavaScript 与 TypeScript之间的联系
2021/11/27 Javascript