使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python升级提示Tkinter模块找不到的解决方法
Aug 22 Python
Python import用法以及与from...import的区别
May 28 Python
解决Python requests 报错方法集锦
Mar 19 Python
Python读取指定目录下指定后缀文件并保存为docx
Apr 23 Python
python微信跳一跳系列之自动计算跳一跳距离
Feb 26 Python
python写入已存在的excel数据实例
May 03 Python
解决win64 Python下安装PIL出错问题(图解)
Sep 03 Python
python使用phoenixdb操作hbase的方法示例
Feb 28 Python
python设计tcp数据包协议类的例子
Jul 23 Python
python实现控制台输出彩色字体
Apr 05 Python
python连接mongodb数据库操作数据示例
Nov 30 Python
python源码剖析之PyObject详解
May 18 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
10个实用的PHP正则表达式汇总
2014/10/23 PHP
php利用ZipArchive类操作文件的实例
2020/01/21 PHP
PHP pthreads v3下同步处理synchronized用法示例
2020/02/21 PHP
如何在Web页面上直接打开、编辑、创建Office文档
2007/03/12 Javascript
dwr spring的集成实现代码
2009/03/22 Javascript
Javascript 学习笔记 错误处理
2009/07/30 Javascript
基于jquery实现发送文章到手机的代码
2014/12/26 Javascript
举例简介AngularJS的内部语言环境
2015/06/17 Javascript
jQuery+CSS实现滑动的标签分栏切换效果
2015/12/17 Javascript
JavaScript实现身份证验证代码
2016/02/17 Javascript
基于jQuery实现Accordion手风琴自定义插件
2020/10/13 Javascript
Nodejs 获取时间加手机标识的32位标识实现代码
2017/03/07 NodeJs
JavaScript 数据类型详解
2017/03/13 Javascript
详解angular用$sce服务来过滤HTML标签
2017/04/11 Javascript
从零开始学习Node.js系列教程三:图片上传和显示方法示例
2017/04/13 Javascript
Angular中的$watch方法详解
2017/09/18 Javascript
bootstrap中日历范围选择插件daterangepicker的使用详解
2018/04/17 Javascript
js动态引入的四种方法
2018/05/05 Javascript
vue动态绑定class选中当前列表变色的方法示例
2018/12/19 Javascript
vue中el-input绑定键盘按键(按键修饰符)
2020/07/22 Javascript
python实现DNS正向查询、反向查询的例子
2014/04/25 Python
跟老齐学Python之坑爹的字符编码
2014/09/28 Python
简化Python的Django框架代码的一些示例
2015/04/20 Python
python将秒数转化为时间格式的实例
2018/09/16 Python
python版飞机大战代码分享
2018/11/20 Python
Python字典的概念及常见应用实例详解
2019/10/30 Python
西班牙英格列斯百货英国官网:El Corte Inglés英国
2017/10/30 全球购物
自我评价范文点评
2013/12/04 职场文书
物业招聘计划书
2014/01/10 职场文书
大学旷课检讨书
2014/01/28 职场文书
主题教育活动总结
2014/05/05 职场文书
cf战队收人口号
2014/06/21 职场文书
美术第二课堂活动总结
2014/07/08 职场文书
区长工作作风个人整改措施
2014/10/01 职场文书
少年派的奇幻漂流观后感
2015/06/08 职场文书
工作自我评价范文
2019/03/21 职场文书