使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现数据库并行读取和写入实例
Jun 09 Python
python的多重继承的理解
Aug 06 Python
Python面向对象编程基础解析(二)
Oct 26 Python
人工智能最火编程语言 Python大战Java!
Nov 13 Python
浅谈Python使用Bottle来提供一个简单的web服务
Dec 27 Python
Python实现平行坐标图的两种方法小结
Jul 04 Python
python2.7的flask框架之引用js&css等静态文件的实现方法
Aug 22 Python
利用rest framework搭建Django API过程解析
Aug 31 Python
python装饰器原理与用法深入详解
Dec 19 Python
Python求平面内点到直线距离的实现
Jan 19 Python
浅谈django channels 路由误导
May 28 Python
如何基于matlab相机标定导出xml文件
Nov 02 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
PHP $_SERVER详解
2009/01/16 PHP
php adodb介绍
2009/03/19 PHP
php和mysql中uft-8中文编码乱码的几种解决办法
2012/04/19 PHP
三个类概括PHP的五种设计模式
2012/09/05 PHP
PHP微框架Dispatch简介
2014/06/12 PHP
php抽象类使用要点与注意事项分析
2015/02/09 PHP
以文件形式缓存php变量的方法
2015/06/26 PHP
用JS实现一个页面多个css样式实现
2008/05/29 Javascript
jQuery 验证插件 Web前端设计模式(asp.net)
2010/10/17 Javascript
Jquery css函数用法(判断标签是否拥有某属性)
2011/05/28 Javascript
JS实现多物体缓冲运动实例代码
2013/11/29 Javascript
jquery实现html页面 div 假分页有原理有代码
2014/09/06 Javascript
jQuery老黄历完整实现方法
2015/01/16 Javascript
jquery实现简单的banner轮播效果【实例】
2016/03/30 Javascript
JavaScript入门教程之引用类型
2016/05/04 Javascript
vue.js移动端tab组件的封装实践实例
2017/06/30 Javascript
让div运动起来 js实现缓动效果
2017/07/06 Javascript
jQuery实现QQ空间汉字转拼音功能示例
2017/07/10 jQuery
Webpack框架核心概念(知识点整理)
2017/12/22 Javascript
js实现二级联动简单实例
2020/01/11 Javascript
three.js中多线程的使用及性能测试详解
2021/01/07 Javascript
Python基类函数的重载与调用实例分析
2015/01/12 Python
python中的字典操作及字典函数
2018/01/03 Python
Python 创建新文件时避免覆盖已有的同名文件的解决方法
2018/11/16 Python
python读取目录下所有的jpg文件,并显示第一张图片的示例
2019/06/13 Python
Argos官网:英国家喻户晓的百货零售连锁商
2017/04/03 全球购物
匡威爱尔兰官网:Converse爱尔兰
2019/06/09 全球购物
*p++ 自增p 还是p所指向的变量
2016/07/16 面试题
英语系本科生求职信范文
2013/12/18 职场文书
《盲人摸象》教学反思
2014/02/16 职场文书
中华美德颂演讲稿
2014/05/20 职场文书
设立有限责任公司出资协议书
2014/11/01 职场文书
2015年医院工作总结范文
2015/04/09 职场文书
2015年学校食堂工作总结
2015/04/22 职场文书
2015年法务工作总结范文
2015/05/23 职场文书
房地产置业顾问工作总结
2015/10/23 职场文书