pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
调试Python程序代码的几种方法总结
Apr 28 Python
用Python实现一个简单的能够上传下载的HTTP服务器
May 05 Python
python实现可以断点续传和并发的ftp程序
Sep 13 Python
Python使用内置json模块解析json格式数据的方法
Jul 20 Python
Python编程使用NLTK进行自然语言处理详解
Nov 16 Python
Python第三方Window模块文件的几种安装方法
Nov 22 Python
对python以16进制打印字节数组的方法详解
Jan 24 Python
python实现知乎高颜值图片爬取
Aug 12 Python
python Pillow图像处理方法汇总
Oct 16 Python
python3实现elasticsearch批量更新数据
Dec 03 Python
Python 一行代码能实现丧心病狂的功能
Jan 18 Python
OpenCV图像变换之傅里叶变换的一些应用
Jul 26 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
比较全的PHP 会话(session 时间设定)使用入门代码
2008/06/05 PHP
php学习笔记 PHP面向对象的程序设计
2011/06/13 PHP
php之readdir函数用法实例
2014/11/13 PHP
PHP不使用内置函数实现字符串转整型的方法示例
2017/07/03 PHP
实用javaScript技术-屏蔽类
2006/08/15 Javascript
动态加载js文件 document.createElement
2006/10/14 Javascript
JavaScript中的History历史对象
2008/01/16 Javascript
javascript针对DOM的应用实例(一)
2012/04/15 Javascript
Jjcarousellite 实现图片列表滚动的简单实例
2013/11/29 Javascript
node.js中的fs.lstat方法使用说明
2014/12/16 Javascript
JavaScript中利用jQuery绑定事件的几种方式小结
2016/03/06 Javascript
js中new一个对象的过程
2017/02/20 Javascript
在JS中如何把毫秒转换成规定的日期时间格式实例
2017/05/11 Javascript
angular select 默认值设置方法
2017/06/23 Javascript
图片懒加载imgLazyLoading.js使用详解
2020/09/15 Javascript
vuejs 制作背景淡入淡出切换动画的实例
2018/09/01 Javascript
利用React Router4实现的服务端直出渲染(SSR)
2019/01/07 Javascript
用JS实现一个简单的打砖块游戏
2019/12/11 Javascript
Python爬豆瓣电影实例
2018/02/23 Python
Python实现K折交叉验证法的方法步骤
2019/07/11 Python
python Pandas如何对数据集随机抽样
2019/07/29 Python
keras实现基于孪生网络的图片相似度计算方式
2020/06/11 Python
解决PyCharm IDE环境下,执行unittest不生成测试报告的问题
2020/09/03 Python
英国鞋类及配饰零售商:Kurt Geiger
2017/02/04 全球购物
美国波西米亚风格精品店:South Moon Under
2019/10/26 全球购物
总经理秘书的岗位职责
2013/12/27 职场文书
竞聘演讲稿范文
2014/01/12 职场文书
奉献爱心演讲稿
2014/09/04 职场文书
党员批评与自我批评思想汇报
2014/10/08 职场文书
八年级英语教学计划
2015/01/23 职场文书
保研推荐信范文
2015/03/25 职场文书
爱心捐书倡议书
2015/04/27 职场文书
归途列车观后感
2015/06/17 职场文书
Go语言应该什么情况使用指针
2021/07/25 Golang
sqlserver连接错误之SQL评估期已过的问题解决
2022/03/23 SQL Server
Win10玩csgo闪退如何解决?Win10玩csgo闪退的解决方法
2022/07/23 数码科技