使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
windows下wxPython开发环境安装与配置方法
Jun 28 Python
老生常谈python之鸭子类和多态
Jun 13 Python
python中的内置函数max()和min()及mas()函数的高级用法
Mar 29 Python
Python占用的内存优化教程
Jul 28 Python
pytorch 可视化feature map的示例代码
Aug 20 Python
在python中创建指定大小的多维数组方式
Nov 28 Python
python 实现按对象传值
Dec 26 Python
np.dot()函数的用法详解
Jan 17 Python
基于python 等频分箱qcut问题的解决
Mar 03 Python
vscode写python时的代码错误提醒和自动格式化的方法
May 07 Python
python中plt.imshow与cv2.imshow显示颜色问题
Jul 16 Python
据Python爬虫不靠谱预测可知今年双十一销售额将超过6000亿元
Nov 11 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
PHP关联链接常用代码
2012/11/05 PHP
解析如何去掉CodeIgniter URL中的index.php
2013/06/25 PHP
PHP生成静态HTML页面最简单方法示例
2015/04/09 PHP
利用PHP将图片转换成base64编码的实现方法
2016/09/13 PHP
PHP实现限制IP访问的方法
2017/04/20 PHP
php实现文章评论系统
2019/02/18 PHP
ext form 表单提交数据的方法小结
2008/08/08 Javascript
jquery tablesorter.js 支持中文表格排序改进
2009/12/09 Javascript
从URL中提取参数与将对象转换为URL查询参数的实现代码
2012/01/12 Javascript
广泛收集的jQuery拖放插件集合
2012/04/09 Javascript
Javascript算符的优先级介绍
2013/03/20 Javascript
Javascript基础教程之while语句
2015/01/18 Javascript
IE中document.createElement的iframe无法设置属性name的解决方法
2015/09/14 Javascript
Kindeditor在线文本编辑器如何过滤HTML
2016/04/14 Javascript
第一章之初识Bootstrap
2016/04/25 Javascript
简单总结JavaScript中的String字符串类型
2016/05/26 Javascript
关于动态执行代码(js的Eval)实例详解
2016/08/15 Javascript
jQuery Ajax File Upload实例源码
2016/12/12 Javascript
nodejs实现大文件(在线视频)的读取
2020/10/16 NodeJs
详解element-ui日期时间选择器的日期格式化问题
2019/04/08 Javascript
layui 实现加载动画以及非真实加载进度的方法
2019/09/23 Javascript
在vue中动态添加class类进行显示隐藏实例
2019/11/09 Javascript
从Python的源码浅要剖析Python的内存管理
2015/04/16 Python
python通过socket查询whois的方法
2015/07/18 Python
详解python eval函数的妙用
2017/11/16 Python
Django admin实现图书管理系统菜鸟级教程完整实例
2017/12/12 Python
详解PyTorch批训练及优化器比较
2018/04/28 Python
如何利用Python开发一个简单的猜数字游戏
2019/09/22 Python
浅谈keras 的抽象后端(from keras import backend as K)
2020/06/16 Python
Smashbox英国官网:美国知名彩妆品牌
2017/11/13 全球购物
罗技英国官方网站:Logitech UK
2020/11/03 全球购物
宣传策划类求职信范文
2014/01/31 职场文书
全国道德模范事迹
2014/02/01 职场文书
长城导游词300字
2015/01/30 职场文书
项目负责人岗位职责
2015/02/15 职场文书
Nginx虚拟主机的搭建的实现步骤
2022/01/18 Servers