pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用PDB简单调试Python程序简明指南
Apr 25 Python
再谈Python中的字符串与字符编码(推荐)
Dec 14 Python
Python类的动态修改的实例方法
Mar 24 Python
python生成随机图形验证码详解
Nov 08 Python
Python针对给定字符串求解所有子序列是否为回文序列的方法
Apr 21 Python
对Python 网络设备巡检脚本的实例讲解
Apr 22 Python
详解numpy的argmax的具体使用
May 27 Python
python+opencv实现移动侦测(帧差法)
Mar 20 Python
Keras之fit_generator与train_on_batch用法
Jun 17 Python
Python txt文件常用读写操作代码实例
Aug 03 Python
python 自动刷新网页的两种方法
Apr 20 Python
详解Python+OpenCV绘制灰度直方图
Mar 22 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
php通过asort()给关联数组按照值排序的方法
2015/03/18 PHP
PHP中的自动加载操作实现方法详解
2019/08/06 PHP
PHP中通过getopt解析GNU C风格命令行选项
2019/11/18 PHP
学习YUI.Ext 第六天--关于树TreePanel(Part 1)
2007/03/10 Javascript
javascript打印html内容功能的方法示例
2013/11/28 Javascript
JavaScript里四舍五入函数round用法实例
2015/04/06 Javascript
JS+CSS实现六级网站导航主菜单效果
2015/09/28 Javascript
window.setInterval()方法的定义和用法及offsetLeft与style.left的区别
2015/11/11 Javascript
AngularJS指令详解及示例代码
2016/08/16 Javascript
jQuery实现加入收藏夹功能(主流浏览器兼职)
2016/12/24 Javascript
AngularJS  ng-repeat遍历输出的用法
2017/06/19 Javascript
详解Vuejs2.0 如何利用proxyTable实现跨域请求
2017/08/03 Javascript
bootstrap confirmation按钮提示组件使用详解
2017/08/22 Javascript
详解nodeJs文件系统(fs)与流(stream)
2018/01/24 NodeJs
深入剖析Node.js cluster模块
2018/05/23 Javascript
react 国际化的实现代码示例
2018/09/14 Javascript
微信小程序模板template简单用法示例
2018/12/04 Javascript
JavaScript实现图片的放大缩小及拖拽功能示例
2019/05/14 Javascript
深入剖析JavaScript instanceof 运算符
2019/06/14 Javascript
Vue开发中遇到的跨域问题及解决方法
2020/02/11 Javascript
[02:08]2018年度CS GO枪械皮肤设计大赛优秀作者-完美盛典
2018/12/16 DOTA
使用python实现个性化词云的方法
2017/06/16 Python
Python实现多并发访问网站功能示例
2017/06/19 Python
Python实现曲线点抽稀算法的示例
2017/10/12 Python
pytorch使用Variable实现线性回归
2019/05/21 Python
Python基础学习之基本数据结构详解【数字、字符串、列表、元组、集合、字典】
2019/06/18 Python
Python Opencv提取图片中某种颜色组成的图形的方法
2019/09/19 Python
python 使用建议与技巧分享(四)
2020/08/18 Python
django中cookiecutter的使用教程
2020/12/03 Python
RentCars.com巴西:汽车租赁网站
2016/08/22 全球购物
服装行业创业计划书范文
2014/02/05 职场文书
《问银河》教学反思
2014/02/19 职场文书
开展党的群众路线教育实践活动个人对照检查材料
2014/11/05 职场文书
优秀志愿者感言
2015/08/01 职场文书
2016党性教育学习心得体会
2016/01/21 职场文书
Redis模仿手机验证码发送的实现示例
2021/11/02 Redis