pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python操作MongoDB数据库PyMongo库使用方法
Apr 27 Python
基于Python的关键字监控及告警
Jul 06 Python
详解appium+python 启动一个app步骤
Dec 20 Python
学习python中matplotlib绘图设置坐标轴刻度、文本
Feb 07 Python
python获取文件路径、文件名、后缀名的实例
Apr 23 Python
将tensorflow的ckpt模型存储为npy的实例
Jul 09 Python
python线程定时器Timer实现原理解析
Nov 30 Python
Python tkinter布局与按钮间距设置方式
Mar 04 Python
Python爬虫实现百度翻译功能过程详解
May 29 Python
python解压zip包中文乱码解决方法
Nov 27 Python
Pandas中两个dataframe的交集和差集的示例代码
Dec 13 Python
python 如何用urllib与服务端交互(发送和接收数据)
Mar 04 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
使用HMAC-SHA1签名方法详解
2013/06/26 PHP
php基于jquery的ajax技术传递json数据简单实例
2016/04/15 PHP
PHP自带方法验证邮箱、URL、IP是否合法的函数
2016/12/08 PHP
PHP设计模式之装饰器模式定义与用法详解
2018/04/02 PHP
js监听鼠标事件控制textarea输入字符串的个数
2014/09/29 Javascript
js中document.write的那点事
2014/12/12 Javascript
JS实现的页面自定义滚动条效果
2015/10/26 Javascript
Javascript必知必会(四)js类型转换
2016/06/08 Javascript
AngularJs Javascript MVC 框架
2016/06/20 Javascript
BootStrap轮播HTML代码(推荐)
2016/12/10 Javascript
微信小程序(六):列表上拉加载下拉刷新示例
2017/01/13 Javascript
JS基于正则表达式实现的密码强度验证功能示例
2017/09/21 Javascript
vue2组件之select2调用的示例代码
2017/10/12 Javascript
vue使用axios跨域请求数据问题详解
2017/10/18 Javascript
vue中使用GraphQL的实例代码
2019/11/04 Javascript
纯js实现无缝滚动功能代码实例
2020/02/21 Javascript
vue父子组件间引用之$parent、$children
2020/05/20 Javascript
[02:32]DOTA2英雄基础教程 美杜莎
2014/01/07 DOTA
[48:18]DOTA2-DPC中国联赛 正赛 RNG vs Dynasty BO3 第二场 1月29日
2021/03/11 DOTA
Python深入学习之对象的属性
2014/08/31 Python
python实现在windows下操作word的方法
2015/04/28 Python
Python中数字以及算数运算符的相关使用
2015/10/12 Python
go和python变量赋值遇到的一个问题
2017/08/31 Python
对python读取CT医学图像的实例详解
2019/01/24 Python
使用tqdm显示Python代码执行进度功能
2019/12/08 Python
pytorch AvgPool2d函数使用详解
2020/01/03 Python
python3.8.1+selenium实现登录滑块验证功能
2020/05/22 Python
详解CSS3:overflow属性
2020/11/17 HTML / CSS
英国电器零售商:PRC Direct
2018/06/21 全球购物
私有程序集与共享程序集有什么区别
2013/04/05 面试题
渡河少年教学反思
2014/02/12 职场文书
年会搞笑主持词
2014/03/27 职场文书
《梅花魂》教学反思
2014/04/30 职场文书
庆国庆活动总结
2014/08/28 职场文书
师德承诺书2015
2015/04/28 职场文书
小学英语听课心得体会
2016/01/14 职场文书