pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的id()函数解密过程
Dec 25 Python
Python使用正则匹配实现抓图代码分享
Apr 02 Python
详解Python中expandtabs()方法的使用
May 18 Python
在Linux系统上安装Python的Scrapy框架的教程
Jun 11 Python
Python设置默认编码为utf8的方法
Jul 01 Python
Python网络爬虫之爬取微博热搜
Apr 18 Python
matplotlib quiver箭图绘制案例
Apr 17 Python
Pandas将列表(List)转换为数据框(Dataframe)
Apr 24 Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 Python
Python 高效编程技巧分享
Sep 10 Python
用python对excel进行操作(读,写,修改)
Dec 25 Python
numpy数据类型dtype转换实现
Apr 24 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
世界咖啡生产者论坛呼吁:需要立即就咖啡价格采取认真行动
2021/03/06 咖啡文化
用Socket发送电子邮件(利用需要验证的SMTP服务器)
2006/10/09 PHP
COM in PHP (winows only)
2006/10/09 PHP
PHP 文件上传进度条的两种实现方法的代码
2007/11/25 PHP
linux下删除7天前日志的代码(php+shell)
2011/01/02 PHP
ajax 的post方法实例(带循环)
2011/07/04 PHP
TP5多入口设置实例讲解
2020/12/15 PHP
JavaScript prototype属性使用说明
2010/05/13 Javascript
JavaScript的常见兼容问题及相关解决方法(chrome/IE/firefox)
2013/12/31 Javascript
javascript通过获取html标签属性class实现多选项卡的方法
2015/07/27 Javascript
request请求获取参数的实现方法(post和get两种方式)
2016/09/27 Javascript
浅谈jquery上下滑动的注意事项
2016/10/13 Javascript
微信小程序教程之本地图片上传(leancloud)实例详解
2016/11/16 Javascript
微信小程序 开发之滑块视图容器(swiper)详解及实例代码
2017/02/22 Javascript
vue解决跨域路由冲突问题思路解析
2017/11/03 Javascript
JS实现可用滑块滑动的缓动图代码
2019/09/01 Javascript
[46:00]DOTA2上海特级锦标赛主赛事日 - 2 胜者组第一轮#4EG VS Fnatic第一局
2016/03/03 DOTA
[44:50]2018DOTA2亚洲邀请赛 4.1 小组赛 A组 TNC vs VG
2018/04/02 DOTA
python提取字典key列表的方法
2015/07/11 Python
浅谈Python实现贪心算法与活动安排问题
2017/12/19 Python
Python使用Shelve保存对象方法总结
2019/01/28 Python
梅尔频率倒谱系数(mfcc)及Python实现
2019/06/18 Python
python Elasticsearch索引建立和数据的上传详解
2019/08/04 Python
matplotlib命令与格式之tick坐标轴日期格式(设置日期主副刻度)
2019/08/06 Python
Python接口测试文件上传实例解析
2020/05/22 Python
销售人员自我评价怎么写
2013/09/19 职场文书
实习生自荐信范文分享
2013/11/27 职场文书
化工专业推荐信范文
2013/11/28 职场文书
六一儿童节活动策划方案
2014/01/27 职场文书
总经理岗位职责范本
2014/02/02 职场文书
电焊工岗位职责
2014/03/06 职场文书
公司保密管理制度
2015/08/04 职场文书
创业计划书之个人工作室
2019/08/22 职场文书
python基础入门之字典和集合
2021/06/13 Python
Win11运行育碧游戏总是崩溃怎么办 win11玩育碧游戏出现性能崩溃的解决办法
2022/04/06 数码科技
shell进度条追踪指令执行时间的场景分析
2022/06/16 Servers