pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进阶教程之文本文件的读取和写入
Aug 29 Python
使用Python的package机制如何简化utils包设计详解
Dec 11 Python
Python之web模板应用
Dec 26 Python
TensorFlow实现RNN循环神经网络
Feb 28 Python
使用Scrapy爬取动态数据
Oct 21 Python
Python绘制并保存指定大小图像的方法
Jan 10 Python
Appium+Python自动化测试之运行App程序示例
Jan 23 Python
python双向链表原理与实现方法详解
Dec 03 Python
python如何实现不可变字典inmutabledict
Jan 08 Python
使用keras实现孪生网络中的权值共享教程
Jun 11 Python
Python如何读写CSV文件
Aug 13 Python
Python几种酷炫的进度条的方式
Apr 11 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
PHP 程序员也要学会使用“异常”
2009/06/16 PHP
采用ThinkPHP中F方法实现快速缓存实例
2014/06/13 PHP
WordPress中给媒体文件添加分类和标签的PHP功能实现
2015/12/31 PHP
php实现基于openssl的加密解密方法
2016/09/30 PHP
一个非常实用的php文件上传类
2017/07/04 PHP
Yii2框架中一些折磨人的坑
2019/12/15 PHP
javascript结合fileReader 实现上传图片
2015/01/30 Javascript
JavaScript简单修改窗口大小的方法
2015/08/03 Javascript
chrome调试javascript详解
2015/10/21 Javascript
js+canvas绘制五角星的方法
2016/01/28 Javascript
jQuery 局部div刷新和全局刷新方法总结
2016/10/05 Javascript
jQuery序列化表单成对象的简单实现
2016/11/29 Javascript
Javascript中Promise的四种常用方法总结
2017/07/14 Javascript
js点击时关闭该范围下拉菜单之外的菜单方法
2018/01/11 Javascript
在vue组件中使用axios的方法
2018/03/16 Javascript
jQuery中元素选择器(element)简单用法示例
2018/05/14 jQuery
javascript设计模式 ? 中介者模式原理与用法实例分析
2020/04/20 Javascript
微信小程序开发打开另一个小程序的实现方法
2020/05/17 Javascript
JS+CSS实现动态时钟
2021/02/19 Javascript
[01:56]林书豪DOTA2上海特级锦标赛励志短片
2016/03/05 DOTA
简要讲解Python编程中线程的创建与锁的使用
2016/02/28 Python
python线程中同步锁详解
2018/04/27 Python
python sort、sort_index方法代码实例
2019/03/28 Python
Python控制Firefox方法总结
2019/06/03 Python
CSS3教程(5):网页背景图片
2009/04/02 HTML / CSS
详解使用HTML5 Canvas创建动态粒子网格动画
2016/12/14 HTML / CSS
英国排名第一的最新设计师品牌手表独立零售商:TIC Watches
2016/09/24 全球购物
意大利制造的男鞋和女鞋:SCAROSSO
2018/03/07 全球购物
Fresh馥蕾诗英国官网:法国LVMH集团旗下高端天然护肤品牌
2018/11/01 全球购物
The North Face北面德国官网:美国著名户外品牌
2018/12/12 全球购物
写给妈妈的道歉信
2014/01/11 职场文书
优秀学生干部推荐材料
2014/02/03 职场文书
写给同学的新学期寄语
2015/02/27 职场文书
幼儿园食品安全责任书
2015/05/08 职场文书
2015年食品安全宣传周活动总结
2015/07/09 职场文书
Golang并发操作中常见的读写锁详析
2021/08/30 Golang