pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python求解平方根的方法
Mar 11 Python
python判断图片宽度和高度后删除图片的方法
May 22 Python
关于Python面向对象编程的知识点总结
Feb 14 Python
用python做一个搜索引擎(Pylucene)的实例代码
Jul 05 Python
详解python OpenCV学习笔记之直方图均衡化
Feb 08 Python
numpy给array增加维度np.newaxis的实例
Nov 01 Python
Python模块的加载讲解
Jan 15 Python
详解Python基础random模块随机数的生成
Mar 23 Python
解决安装python3.7.4报错Can''t connect to HTTPS URL because the SSL module is not available
Jul 31 Python
python try except返回异常的信息字符串代码实例
Aug 15 Python
python 中关于pycharm选择运行环境的问题
Oct 31 Python
python 合并多个excel中同名的sheet
Jan 22 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
如何使用“PHP” 彩蛋进行敏感信息获取
2013/08/07 PHP
php简单生成随机数的方法
2015/07/30 PHP
PHP使用 Imagick 扩展实现图片合成,圆角处理功能示例
2019/09/09 PHP
jquery.combobox中文api和例子,修复了上面的小bug
2011/03/28 Javascript
Extjs中ComboBox加载并赋初值的实现方法
2012/03/22 Javascript
javascript实现数字验证码的简单实例
2014/02/10 Javascript
javascript自定义函数参数传递为字符串格式
2014/07/29 Javascript
通用javascript代码判断版本号是否在版本范围之间
2015/11/29 Javascript
jQuery Ajax 全局调用封装实例代码详解
2016/06/02 Javascript
利用Javascript实现一套自定义事件机制
2017/12/14 Javascript
详解Vue中使用Echarts的两种方式
2018/07/03 Javascript
vue.js提交按钮时进行简单的if判断表达式详解
2018/08/08 Javascript
如何给element添加一个抽屉组件的方法步骤
2019/07/14 Javascript
简单上手Python中装饰器的使用
2015/07/12 Python
使用Python的Twisted框架构建非阻塞下载程序的实例教程
2016/05/25 Python
windows下安装Python和pip终极图文教程
2017/03/05 Python
Python学习小技巧之利用字典的默认行为
2017/05/20 Python
python实现对列表中的元素进行倒序打印
2019/11/23 Python
PyQt5高级界面控件之QTableWidget的具体使用方法
2020/02/23 Python
python实现银行实战系统
2020/02/26 Python
PyInstaller将Python文件打包为exe后如何反编译(破解源码)以及防止反编译
2020/04/15 Python
如何使用Python进行PDF图片识别OCR
2021/01/22 Python
使用CSS3设计地图上的雷达定位提示效果
2016/04/05 HTML / CSS
阻止移动设备(手机、pad)浏览器双击放大网页的方法
2014/06/03 HTML / CSS
美国知名平价彩妆品牌:e.l.f. Cosmetics
2017/11/20 全球购物
运动鞋、街头服装、手表和手袋的实时市场:StockX
2020/11/25 全球购物
个人自我鉴定范文
2013/10/04 职场文书
财务出纳岗位职责
2014/02/03 职场文书
楼面经理岗位职责范本
2014/02/18 职场文书
《难忘的泼水节》教学反思
2014/02/27 职场文书
第一批党的群众路线教育实践活动工作总结
2014/03/03 职场文书
篮球比赛策划方案
2014/06/05 职场文书
2015年超市收银员工作总结
2015/04/25 职场文书
2015双创工作总结
2015/07/24 职场文书
企业文化学习心得体会
2016/01/21 职场文书
MySQL 开窗函数
2022/02/15 MySQL