pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中测试访问同一数据的竞争条件的方法
Apr 23 Python
python 计算文件的md5值实例
Jan 13 Python
Python 列表(List) 的三种遍历方法实例 详解
Apr 15 Python
python获取文件路径、文件名、后缀名的实例
Apr 23 Python
利用python和ffmpeg 批量将其他图片转换为.yuv格式的方法
Jan 08 Python
python3.7 使用pymssql往sqlserver插入数据的方法
Jul 08 Python
Python 实现数据结构-循环队列的操作方法
Jul 17 Python
python字典的遍历3种方法详解
Aug 10 Python
Pyspark读取parquet数据过程解析
Mar 27 Python
python查找特定名称文件并按序号、文件名分行打印输出的方法
Apr 24 Python
Python Sqlalchemy如何实现select for update
Oct 12 Python
Python实现归一化算法详情
Mar 18 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
PHP date函数参数详解
2006/11/27 PHP
php 文章采集正则代码
2009/12/28 PHP
基于php实现长连接的方法与注意事项的问题
2013/05/10 PHP
关于PHPDocument 代码注释规范的总结
2013/06/25 PHP
浅谈php自定义错误日志
2015/02/13 PHP
php通过文件头判断格式的方法
2016/05/28 PHP
CI框架封装的常用图像处理方法(缩略图,水印,旋转,上传等)
2016/11/22 PHP
使用jQuery轻松实现Ajax的实例代码
2010/08/16 Javascript
充分发挥Node.js程序性能的一些方法介绍
2015/06/23 Javascript
基于ajax实现文件上传并显示进度条
2015/08/03 Javascript
jQuery动态生成不规则表格(前后端)
2017/02/21 Javascript
jQuery实现table表格信息的展开和缩小功能示例
2018/07/21 jQuery
webpack4.0 入门实践教程
2018/10/08 Javascript
javascript数组去重方法总结(推荐)
2019/03/20 Javascript
微信小程序发布新版本时自动提示用户更新的方法
2019/06/07 Javascript
Vue 实例事件简单示例
2019/09/19 Javascript
微信小程序修改checkbox的样式代码实例
2020/01/21 Javascript
详解Vue2的diff算法
2021/01/06 Vue.js
[01:50]WODOTA制作 DOTA2中文宣传片《HERO》
2013/04/28 DOTA
[01:52]2020年DOTA2 TI10夏季活动预告片
2020/07/15 DOTA
利用python获得时间的实例说明
2013/03/25 Python
Python中每次处理一个字符的5种方法
2015/05/21 Python
python版opencv摄像头人脸实时检测方法
2018/08/03 Python
Pytorch保存模型用于测试和用于继续训练的区别详解
2020/01/10 Python
关于keras.layers.Conv1D的kernel_size参数使用介绍
2020/05/22 Python
在Python3.74+PyCharm2020.1 x64中安装使用Kivy的详细教程
2020/08/07 Python
Python大批量搜索引擎图像爬虫工具详解
2020/11/16 Python
Window10上Tensorflow的安装(CPU和GPU版本)
2020/12/15 Python
CSS3制作皮卡丘动画壁纸的示例
2020/11/02 HTML / CSS
婚假请假条格式及范文
2014/04/10 职场文书
工作疏忽、懈怠的检讨书
2014/09/11 职场文书
公务员年度考核评语
2014/12/31 职场文书
同学聚会通知短信
2015/04/20 职场文书
未婚证明格式
2015/06/15 职场文书
《我的伯父鲁迅先生》教学反思
2016/02/16 职场文书
剑指Offer之Java算法习题精讲二叉树专项训练
2022/03/21 Java/Android