使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python Django连接MySQL数据库做增删改查
Nov 07 Python
详解Python中的装饰器、闭包和functools的教程
Apr 02 Python
Python 装饰器使用详解
Jul 29 Python
python 列表,数组,矩阵两两转换tolist()的实例
Apr 04 Python
使用python Fabric动态修改远程机器hosts的方法
Oct 26 Python
利用Python对文件夹下图片数据进行批量改名的代码实例
Feb 21 Python
OpenCV HSV颜色识别及HSV基本颜色分量范围
Mar 22 Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 Python
Python统计分析模块statistics用法示例
Sep 06 Python
Anaconda 查看、创建、管理和使用python环境的方法
Dec 03 Python
Python中的特殊方法以及应用详解
Sep 20 Python
Python基础之函数嵌套知识总结
May 23 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
全国FM电台频率大全 - 22 重庆市
2020/03/11 无线电
分享最受欢迎的5款PHP框架
2014/11/27 PHP
PHP开发中AJAX技术的简单应用
2015/12/11 PHP
PHP身份证校验码计算方法
2016/08/10 PHP
详解PHP处理密码的几种方式
2016/11/30 PHP
BOOM vs RR BO5 第三场 2.14
2021/03/10 DOTA
javascript中的undefined 与 null 的区别  补充篇
2010/03/17 Javascript
JS 模态对话框和非模态对话框操作技巧汇总
2013/04/15 Javascript
jquery获取css中的选择器(实例讲解)
2013/12/02 Javascript
javascript版的in_array函数(判断数组中是否存在特定值)
2014/05/09 Javascript
JavaScript实现网页对象拖放功能的方法
2015/04/15 Javascript
JavaScript中的数据类型转换方法小结
2015/10/26 Javascript
javascript拖拽应用实例
2016/03/25 Javascript
老生常谈JQuery data方法的使用
2016/09/09 Javascript
vuejs开发组件分享之H5图片上传、压缩及拍照旋转的问题处理
2017/03/06 Javascript
jQuery实现注册会员时密码强度提示信息功能示例
2017/09/05 jQuery
js微信分享接口调用详解
2019/07/23 Javascript
webpack.DefinePlugin与cross-env区别详解
2020/02/23 Javascript
理解JavaScript中的Proxy 与 Reflection API
2020/09/21 Javascript
Vue实现手机号、验证码登录(60s禁用倒计时)
2020/12/19 Vue.js
[48:21]Mski vs VGJ.S Supermajor小组赛C组 BO3 第一场 6.3
2018/06/04 DOTA
[41:52]DOTA2-DPC中国联赛 正赛 CDEC vs Dynasty BO3 第二场 2月22日
2021/03/11 DOTA
python中的yield使用方法
2014/02/11 Python
python实现在windows下操作word的方法
2015/04/28 Python
在Linux系统上通过uWSGI配置Nginx+Python环境的教程
2015/12/25 Python
python基于plotly实现画饼状图代码实例
2019/12/16 Python
Python pip install如何修改默认下载路径
2020/04/29 Python
Python unittest单元测试openpyxl实现过程解析
2020/05/27 Python
Pytorch实现WGAN用于动漫头像生成
2021/03/04 Python
技校生自我鉴定
2013/12/08 职场文书
会计电算化专业毕业生推荐信
2013/12/24 职场文书
村优秀党员事迹材料
2014/01/15 职场文书
美德好少年事迹材料
2014/01/19 职场文书
工程承诺书怎么写
2014/05/24 职场文书
社区结对共建协议书
2016/03/23 职场文书
Redis 彻底禁用RDB持久化操作
2021/07/09 Redis