pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
一键搞定python连接mysql驱动有关问题(windows版本)
Apr 23 Python
python使用mysql的两种使用方式
Mar 07 Python
使用pytorch进行图像的顺序读取方法
Jul 27 Python
scrapy-redis源码分析之发送POST请求详解
May 15 Python
python3 tkinter实现点击一个按钮跳出另一个窗口的方法
Jun 13 Python
对Django外键关系的描述
Jul 26 Python
如何用Python来理一理红楼梦里的那些关系
Aug 14 Python
python如何将两个txt文件内容合并
Oct 18 Python
django 多数据库及分库实现方式
Apr 01 Python
Python ORM框架Peewee用法详解
Apr 29 Python
matplotlib bar()实现百分比堆积柱状图
Feb 24 Python
python实现学生通讯录管理系统
Feb 25 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
PHP面向对象详解(三)
2015/12/07 PHP
php 如何获取文件的后缀名
2016/06/05 PHP
PHP怎样用正则抓取页面中的网址
2016/08/09 PHP
php 多文件上传的实现实例
2016/10/23 PHP
PHP接口继承及接口多继承原理与实现方法详解
2017/10/18 PHP
javascript smipleChart 简单图标类
2011/01/12 Javascript
Jquery知识点二 jquery下对数组的操作
2011/01/15 Javascript
浅析IE10兼容性问题(frameset的cols属性)
2014/01/03 Javascript
用json方式实现在 js 中建立一个map
2014/05/02 Javascript
js实现的万能flv网页播放器代码
2016/04/30 Javascript
javascript 小数乘法结果错误的处理方法
2016/07/28 Javascript
jquery实现input框获取焦点的方法
2017/02/06 Javascript
如何将 jQuery 从你的 Bootstrap 项目中移除(取而代之使用Vue.js)
2017/07/17 jQuery
从setTimeout看js函数执行过程
2017/12/19 Javascript
基于Node的Axure文件在线预览的实现代码
2019/08/28 Javascript
微信小程序保持session会话的方法
2020/03/20 Javascript
详解react组件通讯方式(多种)
2020/05/06 Javascript
python判断字符串是否包含子字符串的方法
2015/03/24 Python
浅谈Python中(&,|)和(and,or)之间的区别
2019/08/07 Python
Python中zip()函数的简单用法举例
2019/09/02 Python
Python的条件锁与事件共享详解
2019/09/12 Python
jupyter notebook 实现matplotlib图动态刷新
2020/04/22 Python
Python参数传递对象的引用原理解析
2020/05/22 Python
python unichr函数知识点总结
2020/12/16 Python
pytho matplotlib工具栏源码探析一之禁用工具栏、默认工具栏和工具栏管理器三种模式的差异
2021/02/25 Python
固特异美国在线轮胎店:Goodyear Tire
2019/02/23 全球购物
LightInTheBox法国站:中国跨境电商
2020/03/05 全球购物
大学生职业生涯规划书范文
2014/01/14 职场文书
新文化运动的基本口号
2014/06/21 职场文书
2014初中数学教研组工作总结
2014/12/19 职场文书
初中重阳节活动总结
2015/05/05 职场文书
医生行业员工的辞职信
2019/06/24 职场文书
如何书写公司员工保密协议?
2019/06/27 职场文书
高中议论文(范文2篇)
2019/08/19 职场文书
在 SQL 语句中处理 NULL 值的方法
2021/06/07 SQL Server
MySQL创建定时任务
2022/01/22 MySQL