使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python条件和循环的使用方法
Nov 01 Python
Python 安装setuptools和pip工具操作方法(必看)
May 22 Python
Python学习思维导图(必看篇)
Jun 26 Python
Python中创建二维数组
Oct 17 Python
Python实现常见的回文字符串算法
Nov 14 Python
opencv实现图片模糊和锐化操作
Nov 19 Python
Django 后台获取文件列表 InMemoryUploadedFile的例子
Aug 07 Python
python实现删除列表中某个元素的3种方法
Jan 15 Python
Python字符串格式化常用手段及注意事项
Jun 17 Python
Python用dilb提取照片上人脸的示例
Oct 26 Python
python入门教程之基本算术运算符
Nov 13 Python
python读取excel数据并且画图的实现示例
Feb 08 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
php set_time_limit()函数的使用详解
2013/06/05 PHP
PHP基于imagick扩展实现合成图片的两种方法【附imagick扩展下载】
2017/11/14 PHP
表单元素事件 (Form Element Events)
2009/07/17 Javascript
Javascript面向对象扩展库代码分享
2012/03/27 Javascript
javascript 使用 NodeList需要注意的问题
2013/03/04 Javascript
浅析ajax请求json数据并用js解析(示例分析)
2013/07/13 Javascript
js 通用订单代码
2013/12/23 Javascript
javascript属性访问表达式用法分析
2015/04/25 Javascript
AngularJS入门教程之静态模板详解
2016/08/18 Javascript
第一次接触Bootstrap框架
2016/10/24 Javascript
JS类的定义与使用方法深入探索
2016/11/26 Javascript
angular ng-repeat数组中的数组实例
2017/02/18 Javascript
超简单的Vue.js环境搭建教程
2017/03/17 Javascript
vue获取input输入值的问题解决办法
2017/10/17 Javascript
Three.js实现3D机房效果
2018/12/30 Javascript
如何实现一个简易版的vuex持久化工具
2019/09/11 Javascript
使用Node.js在深度学习中做图片预处理的方法
2019/09/18 Javascript
原生js实现随机点名功能
2019/11/05 Javascript
Python 自动刷博客浏览量实例代码
2017/06/14 Python
Python常见字符串操作函数小结【split()、join()、strip()】
2018/02/02 Python
python脚本作为Windows服务启动代码详解
2018/02/11 Python
Python管理Windows服务小脚本
2018/03/12 Python
Python unittest 简单实现参数化的方法
2018/11/30 Python
Python识别快递条形码及Tesseract-OCR使用详解
2019/07/15 Python
python中plt.imshow与cv2.imshow显示颜色问题
2020/07/16 Python
解决pip安装的第三方包在PyCharm无法导入的问题
2020/10/15 Python
学会迭代器设计模式,帮你大幅提升python性能
2021/01/03 Python
迪拜航空官方网站:flydubai
2017/04/20 全球购物
美国葡萄酒网上商店:Martha Stewart Wine Co.
2019/03/17 全球购物
联想智利官方网站:Lenovo Chile
2020/06/03 全球购物
自我评价怎么写好呢?
2013/12/05 职场文书
五一服装活动方案
2014/01/11 职场文书
购房协议书范本
2014/10/02 职场文书
三年级学生评语大全
2014/12/26 职场文书
2016年教师学习廉政准则心得体会
2016/01/20 职场文书
python实现一个简单的贪吃蛇游戏附代码
2022/06/28 Python