使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取当前用户的主目录路径方法(推荐)
Jan 12 Python
python读取二进制mnist实例详解
May 31 Python
Python设置在shell脚本中自动补全功能的方法
Jun 25 Python
python调用虹软2.0第三版的具体使用
Feb 22 Python
Python根据当前日期取去年同星期日期
Apr 14 Python
Python实现微信消息防撤回功能的实例代码
Apr 29 Python
使用python绘制温度变化雷达图
Oct 18 Python
浅谈Python访问MySQL的正确姿势
Jan 07 Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 Python
django-利用session机制实现唯一登录的例子
Mar 16 Python
Python函数参数定义及传递方式解析
Jun 10 Python
Python爬虫基础之初次使用scrapy爬虫实例
Jun 26 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
PHPMYADMIN导入数据最大为2M的解决方法
2012/04/23 PHP
PHP中CURL方法curl_setopt()函数的参数分享
2013/01/19 PHP
解析PHP生成静态html文件的三种方法
2013/06/18 PHP
php 中的closure用法详解
2017/06/12 PHP
使用EXT实现无刷新动态调用股票信息
2008/11/01 Javascript
javascript 二分法(数组array)
2010/04/24 Javascript
Three.js源码阅读笔记(基础的核心Core对象)
2012/12/27 Javascript
js比较和逻辑运算符的介绍
2013/03/10 Javascript
JS代码同步文本框内容的实例方法
2013/07/12 Javascript
node.js中的fs.readdirSync方法使用说明
2014/12/17 Javascript
angular.js指令中的controller、compile与link函数的不同之处
2017/05/10 Javascript
js动态设置select下拉菜单的默认选中项实例
2018/08/21 Javascript
JS遍历JSON数组及获取JSON数组长度操作示例【测试可用】
2018/12/12 Javascript
vue学习之Vue-Router用法实例分析
2020/01/06 Javascript
react+antd 递归实现树状目录操作
2020/11/02 Javascript
[03:40]2014DOTA2国际邀请赛 B神专访:躲箭真的很难
2014/07/13 DOTA
[01:14:55]EG vs Spirit Supermajor 败者组 BO3 第三场 6.4
2018/06/05 DOTA
浅谈Python中chr、unichr、ord字符函数之间的对比
2016/06/16 Python
详解python基础之while循环及if判断
2017/08/24 Python
opencv python 图像去噪的实现方法
2018/08/31 Python
Python压缩模块zipfile实现原理及用法解析
2020/08/14 Python
Python实现敏感词过滤的4种方法
2020/09/12 Python
python爬虫分布式获取数据的实例方法
2020/11/26 Python
英语专业学生个人求职信范文
2014/01/06 职场文书
我爱读书演讲稿
2014/05/07 职场文书
企业晚会策划方案
2014/05/29 职场文书
爱耳日宣传活动总结
2014/07/05 职场文书
工作目标责任书
2014/07/23 职场文书
2014年度个人工作总结
2014/11/07 职场文书
招标保密承诺书
2015/01/20 职场文书
英文商务邀请函范文
2015/01/31 职场文书
堂吉诃德读书笔记
2015/06/30 职场文书
2015新教师教学工作总结
2015/07/22 职场文书
2016大学生暑期社会实践心得体会
2016/01/14 职场文书
《静夜思》教学反思
2016/02/17 职场文书
十大最强格斗系宝可梦,超梦X仅排第十,第二最重格斗礼仪
2022/03/18 日漫