pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3使用tkinter实现ui界面简单实例
Jan 10 Python
Python中实现参数类型检查的简单方法
Apr 21 Python
对Python的Django框架中的项目进行单元测试的方法
Apr 11 Python
python二分查找算法的递归实现方法
May 12 Python
python psutil库安装教程
Mar 19 Python
Python Socket编程之多线程聊天室
Jul 28 Python
selenium在执行phantomjs的API并获取执行结果的方法
Dec 17 Python
Python调用百度根据经纬度查询地址的示例代码
Jul 07 Python
python 字符串常用方法汇总详解
Sep 16 Python
超实用的 30 段 Python 案例
Oct 10 Python
jupyter 使用Pillow包显示图像时inline显示方式
Apr 24 Python
Django中的模型类设计及展示示例详解
May 29 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
PHP 高手之路(一)
2006/10/09 PHP
PHP使用递归生成文章树
2015/04/21 PHP
PHP中的事务使用实例
2015/05/26 PHP
PHP7正式版测试,性能惊艳!
2015/12/08 PHP
建议大家看下JavaScript重要知识更新
2007/07/08 Javascript
jQuery find和children方法使用
2011/01/31 Javascript
原生Js实现元素渐隐/渐现(原理为修改元素的css透明度)
2013/06/24 Javascript
各浏览器对document.getElementById等方法的实现差异解析
2013/12/05 Javascript
js中top的作用深入剖析
2014/03/04 Javascript
JQuery 图片滚动轮播示例代码
2014/03/24 Javascript
用javascript读取xml文件读取节点数据
2014/08/12 Javascript
浅析jQuery Mobile的初始化事件
2015/12/03 Javascript
confirm确认对话框的实现方法总结
2016/06/17 Javascript
BootStrap整体框架之基础布局组件
2016/12/15 Javascript
javascript显示系统当前时间代码
2016/12/29 Javascript
利用nginx + node在阿里云部署https的步骤详解
2017/12/19 Javascript
angular第三方包开发整理(小结)
2018/04/19 Javascript
JavaScript实现的简单加密解密操作示例
2018/06/01 Javascript
微信小程序接入腾讯云验证码的方法步骤
2020/01/07 Javascript
JS XMLHttpRequest原理与使用方法深入详解
2020/04/30 Javascript
js代码实现轮播图
2020/05/04 Javascript
Python os模块介绍
2014/11/30 Python
python爬虫的一个常见简单js反爬详解
2019/07/09 Python
python如何获取apk的packagename和activity
2020/01/10 Python
解决python执行较大excel文件openpyxl慢问题
2020/05/15 Python
python matplotlib库的基本使用
2020/09/23 Python
Python+Xlwings 删除Excel的行和列
2020/12/19 Python
HTML5中的nav标签学习笔记
2016/06/24 HTML / CSS
详解使用HTML5 Canvas创建动态粒子网格动画
2016/12/14 HTML / CSS
法国家具及室内配件店:home24
2017/01/21 全球购物
俄罗斯运动鞋商店:Sneakerhead
2018/05/10 全球购物
英国二手iPhone、音乐、电影和游戏商店:musicMagpie
2018/10/26 全球购物
俄罗斯汽车零件和配件在线商店:CarvilleShop
2019/11/29 全球购物
电气工程及自动化专业自荐书范文
2013/12/18 职场文书
大学生学习计划书
2014/09/15 职场文书
年会邀请函范文
2015/01/30 职场文书