pytorch中获取模型input/output shape实例


Posted in Python on December 30, 2019

Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见

https://github.com/pytorch/pytorch/pull/3043

以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。

例如RNN中bias是bool类型,其权重也不存于weight属性中;不过我们只关注shape,这已经够用了。

该方法必须先构造一个输入并调用forward(即执行model(x))之后,才能获取shape。

#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json
 
 
def get_output_size(summary_dict, output):
    """Record the shape(s) of a module's forward output in ``summary_dict``.

    A single tensor-like ``output`` (anything exposing ``.size()``) is stored
    under the ``'output_shape'`` key as a list of ints.  A tuple or list of
    outputs (as returned by the LSTM layers in the example model, for
    instance) is stored recursively: each element gets its own nested
    ``OrderedDict`` keyed by its integer position.

    Args:
        summary_dict: Mapping to fill in; it is mutated in place.
        output: An object with a ``size()`` method, or an arbitrarily nested
            tuple/list of such objects.

    Returns:
        The same ``summary_dict`` object, for call-chaining convenience.
    """
    if isinstance(output, (tuple, list)):
        # enumerate() replaces the Python-2-only xrange(len(...)) loop and
        # works identically on Python 2 and 3.
        for i, item in enumerate(output):
            summary_dict[i] = get_output_size(OrderedDict(), item)
    else:
        summary_dict['output_shape'] = list(output.size())
    return summary_dict
 
def summary(input_size, model):
    """Run one forward pass through ``model`` and collect per-layer shapes.

    A forward hook is registered on every sub-module except
    ``nn.Sequential`` / ``nn.ModuleList`` containers and ``model`` itself.
    A random input with a leading batch dimension of 1 is then fed through
    the model, and each hook records what it saw.

    Args:
        input_size: Shape of one input *without* the batch dimension
            (e.g. ``[1, 32, 128]``), or a list/tuple of such shapes when the
            model takes several positional inputs.
        model: The ``nn.Module`` to inspect.

    Returns:
        An ``OrderedDict`` keyed by ``"<ClassName>-<n>"`` in call order.
        Each value holds ``'input_shape'`` (shape of the first positional
        input), the output shape(s) as written by ``get_output_size``,
        ``'trainable'`` (only when the module has a ``weight``) and
        ``'nb_params'`` (element count of ``weight`` only; bias is not
        counted).

    Raises:
        Whatever the model's forward pass raises; the hooks are removed
        before the exception propagates.
    """
    layer_info = OrderedDict()
    hooks = []

    def register_hook(module):
        def hook(module, inputs, output):
            class_name = module.__class__.__name__
            module_idx = len(layer_info)

            m_key = '%s-%i' % (class_name, module_idx + 1)
            layer_info[m_key] = OrderedDict()
            layer_info[m_key]['input_shape'] = list(inputs[0].size())
            layer_info[m_key] = get_output_size(layer_info[m_key], output)

            params = 0
            # Some modules expose ``weight = None`` (hasattr would still be
            # true), so test the value rather than the attribute's presence.
            weight = getattr(module, 'weight', None)
            if weight is not None:
                # Plain int keeps the result JSON-serialisable.
                params += int(weight.numel())
                layer_info[m_key]['trainable'] = bool(weight.requires_grad)

            layer_info[m_key]['nb_params'] = params

        is_container = isinstance(module, (nn.Sequential, nn.ModuleList))
        if not is_container and module is not model:
            hooks.append(module.register_forward_hook(hook))

    # Build one random input per declared shape, each with batch size 1.
    if isinstance(input_size[0], (list, tuple)):
        x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
    else:
        x = [Variable(torch.rand(1, *input_size))]

    model.apply(register_hook)
    try:
        # Unpack so that multiple inputs arrive as separate positional
        # arguments instead of one list.
        model(*x)
    finally:
        # Always detach the hooks, even if the forward pass failed, so the
        # model is left exactly as it was handed in.
        for h in hooks:
            h.remove()

    return layer_info
 
# Example: build the CRNN model and dump its per-layer shape summary as JSON.
# The meaning of the constructor arguments is defined by models.crnn, which
# is not shown here.  The instance is bound to ``model`` so that the imported
# ``crnn`` module is not shadowed.
model = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1, 32, 128], model)
# Parenthesised single-argument print works on both Python 2 and Python 3.
print(json.dumps(x))

以pytorch版CRNN为例,输出shape如下

{
"Conv2d-1": {
"input_shape": [1, 1, 32, 128],
"output_shape": [1, 64, 32, 128],
"trainable": true,
"nb_params": 576
},
"ReLU-2": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 32, 128],
"nb_params": 0
},
"MaxPool2d-3": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 16, 64],
"nb_params": 0
},
"Conv2d-4": {
"input_shape": [1, 64, 16, 64],
"output_shape": [1, 128, 16, 64],
"trainable": true,
"nb_params": 73728
},
"ReLU-5": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 16, 64],
"nb_params": 0
},
"MaxPool2d-6": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 8, 32],
"nb_params": 0
},
"Conv2d-7": {
"input_shape": [1, 128, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 294912
},
"BatchNorm2d-8": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 256
},
"ReLU-9": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"Conv2d-10": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 589824
},
"ReLU-11": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"MaxPool2d-12": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 4, 33],
"nb_params": 0
},
"Conv2d-13": {
"input_shape": [1, 256, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 1179648
},
"BatchNorm2d-14": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-15": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"Conv2d-16": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 2359296
},
"ReLU-17": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"MaxPool2d-18": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 2, 34],
"nb_params": 0
},
"Conv2d-19": {
"input_shape": [1, 512, 2, 34],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 1048576
},
"BatchNorm2d-20": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-21": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"nb_params": 0
},
"LSTM-22": {
"input_shape": [33, 1, 512],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-23": {
"input_shape": [33, 512],
"output_shape": [33, 256],
"trainable": true,
"nb_params": 131072
},
"BidirectionalLSTM-24": {
"input_shape": [33, 1, 512],
"output_shape": [33, 1, 256],
"nb_params": 0
},
"LSTM-25": {
"input_shape": [33, 1, 256],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-26": {
"input_shape": [33, 512],
"output_shape": [33, 3755],
"trainable": true,
"nb_params": 1922560
},
"BidirectionalLSTM-27": {
"input_shape": [33, 1, 256],
"output_shape": [33, 1, 3755],
"nb_params": 0
}
}

以上这篇pytorch中获取模型input/output shape实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
vc6编写python扩展的方法分享
Jan 17 Python
Python基类函数的重载与调用实例分析
Jan 12 Python
Python中的Matplotlib模块入门教程
Apr 15 Python
python编写暴力破解zip文档程序的实例讲解
Apr 24 Python
pygame实现五子棋游戏
Oct 29 Python
学Python 3的理由和必要性
Nov 19 Python
tensorflow没有output结点,存储成pb文件的例子
Jan 04 Python
解决Pycharm中恢复被exclude的项目问题(pycharm source root)
Feb 14 Python
使用python实现飞机大战游戏
Mar 23 Python
Python的in,is和id函数代码实例
Apr 18 Python
Python的logging模块基本用法
Dec 24 Python
Pycharm 跳转回之前所在页面的操作
Feb 05 Python
Python读取csv文件实例解析
Dec 30 #Python
Pytorch Tensor的统计属性实例讲解
Dec 30 #Python
PyTorch中permute的用法详解
Dec 30 #Python
python实现多进程按序号批量修改文件名的方法示例
Dec 30 #Python
Pytorch Tensor基本数学运算详解
Dec 30 #Python
python垃圾回收机制(GC)原理解析
Dec 30 #Python
利用Python代码实现一键抠背景功能
Dec 29 #Python
You might like
php数组函数序列之array_search()- 按元素值返回键名
2011/11/04 PHP
PHP中Cookie的使用详解(简单易懂)
2017/04/28 PHP
PHP多线程模拟实现秒杀抢单
2018/02/07 PHP
新手入门常用代码集锦
2007/01/11 Javascript
jQuery实现自动切换播放的经典滑动门效果
2015/09/12 Javascript
javascript实现自动填写表单实例简析
2015/12/02 Javascript
JQuery学习总结【二】
2016/12/01 Javascript
jQuery ajax请求struts action实现异步刷新
2017/04/19 jQuery
javascript 开发之百度地图使用到的js函数整理
2017/05/19 Javascript
在vue中通过axios异步使用echarts的方法
2018/01/13 Javascript
Vue中的slot使用插槽分发内容的方法
2018/03/01 Javascript
浅析Vue 生命周期
2018/06/21 Javascript
Electron 调用命令行(cmd)
2019/09/23 Javascript
通过js实现压缩图片上传功能
2020/02/25 Javascript
[01:21:58]守擂赛DOTA2第一周决赛
2020/04/22 DOTA
C#返回当前系统所有可用驱动器符号的方法
2015/04/18 Python
Python基于Pymssql模块实现连接SQL Server数据库的方法详解
2017/07/20 Python
基于使用paramiko执行远程linux主机命令(详解)
2017/10/16 Python
Python输出由1,2,3,4组成的互不相同且无重复的三位数
2018/02/01 Python
对Python 检查文件名是否规范的实例详解
2019/06/10 Python
关于Python内存分配时的小秘密分享
2019/09/05 Python
Docker部署Python爬虫项目的方法步骤
2020/01/19 Python
Pandas将列表(List)转换为数据框(Dataframe)
2020/04/24 Python
pytorch加载语音类自定义数据集的方法教程
2020/11/10 Python
Python实现网络聊天室的示例代码(支持多人聊天与私聊)
2021/01/27 Python
Django与AJAX实现网页动态数据显示的示例代码
2021/02/24 Python
CSS3使用transition属性实现过渡效果
2018/04/18 HTML / CSS
日本一家专门经营各种箱包的大型网站:Traveler Store
2016/08/03 全球购物
Zadig&Voltaire官网:法国时装品牌
2018/01/05 全球购物
介绍一下如何利用路径遍历进行攻击及如何防范
2014/01/19 面试题
心理健康教育制度
2014/01/27 职场文书
后勤部经理岗位职责
2014/02/23 职场文书
充分就业社区汇报材料
2014/05/07 职场文书
信息工作经验交流材料
2014/05/28 职场文书
2015年大学生工作总结
2015/04/21 职场文书
idea搭建可运行Servlet的Web项目
2021/06/26 Java/Android