pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
50行代码实现贪吃蛇(具体思路及代码)
Apr 27 Python
Python中的Matplotlib模块入门教程
Apr 15 Python
用Python程序抓取网页的HTML信息的一个小实例
May 02 Python
Django中模型Model添加JSON类型字段的方法
Jun 17 Python
使用python Fabric动态修改远程机器hosts的方法
Oct 26 Python
python实现几种归一化方法(Normalization Method)
Jul 31 Python
django创建简单的页面响应实例教程
Sep 06 Python
python函数map()和partial()的知识点总结
May 26 Python
PyTorch-GPU加速实例
Jun 23 Python
python3.8动态人脸识别的实现示例
Sep 21 Python
PyChon中关于Jekins的详细安装(推荐)
Dec 28 Python
python time.strptime格式化实例详解
Feb 03 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
php模板之Phpbean的目录结构
2008/01/10 PHP
php生成xml简单实例代码
2009/12/16 PHP
PHP为表单获取的URL 地址预设 http 字符串函数代码
2010/05/26 PHP
ThinkPHP分组下自定义标签库实例
2014/11/01 PHP
yii2.0数据库迁移教程【多个数据库同时同步数据】
2016/10/08 PHP
PHP+MYSQL实现读写分离简单实战
2017/03/13 PHP
javascript 设为首页与加入收藏兼容多浏览器代码
2011/01/11 Javascript
从零开始学习jQuery (十一) 实战表单验证与自动完成提示插件
2011/02/23 Javascript
通过下拉框的值来确定输入框是否可以为空的代码
2011/10/18 Javascript
js confirm()方法的使用方法实例
2013/07/13 Javascript
js中arguments的用法(实例讲解)
2013/11/30 Javascript
JS实现两表格里数据来回转移的方法
2015/05/28 Javascript
JQuery Mobile实现导航栏和页脚
2016/03/09 Javascript
jQuery 监控键盘一段时间没输入
2016/04/22 Javascript
3kb jQuery代码搞定各种树形选择的实现方法
2016/06/10 Javascript
微信小程序-详解数据缓存
2016/11/24 Javascript
Python操作CouchDB数据库简单示例
2015/03/10 Python
Windows系统下多版本pip的共存问题详解
2017/10/10 Python
Window环境下Scrapy开发环境搭建
2018/11/18 Python
详解python中的hashlib模块的使用
2019/04/22 Python
Python中新式类与经典类的区别详析
2019/07/10 Python
python 实现GUI(图形用户界面)编程详解
2019/07/17 Python
Python3加密解密库Crypto的RSA加解密和签名/验签实现方法实例
2020/02/11 Python
python脚本实现mp4中的音频提取并保存在原目录
2020/02/27 Python
美国儿童珠宝在线零售商:Loveivy
2019/05/22 全球购物
Eton丹麦官网:精美的男式衬衫
2020/05/27 全球购物
城市规划毕业生求职信
2013/10/10 职场文书
高级护理专业大学生求职信
2013/10/24 职场文书
竞选劳动委员演讲稿
2014/04/28 职场文书
疾病防治方案
2014/05/31 职场文书
2014年巴西世界杯口号
2014/06/05 职场文书
贫困证明书格式及范文
2014/10/15 职场文书
幼儿园教师暑期培训心得体会
2016/01/09 职场文书
2016教师学习党章心得体会
2016/01/15 职场文书
读后感怎么写?书写读后感的基本技巧!
2019/12/10 职场文书
python实现双链表
2022/05/25 Python