pytorch中获取模型input/output shape实例


Posted in Python onDecember 30, 2019

Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见

https://github.com/pytorch/pytorch/pull/3043

以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。

例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了。

该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape

#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json
 
 
def get_output_size(summary_dict, output):
 if isinstance(output, tuple):
 for i in xrange(len(output)):
  summary_dict[i] = OrderedDict()
  summary_dict[i] = get_output_size(summary_dict[i],output[i])
 else:
 summary_dict['output_shape'] = list(output.size())
 return summary_dict
 
def summary(input_size, model):
 def register_hook(module):
 def hook(module, input, output):
  class_name = str(module.__class__).split('.')[-1].split("'")[0]
  module_idx = len(summary)
 
  m_key = '%s-%i' % (class_name, module_idx+1)
  summary[m_key] = OrderedDict()
  summary[m_key]['input_shape'] = list(input[0].size())
  summary[m_key] = get_output_size(summary[m_key], output)
 
  params = 0
  if hasattr(module, 'weight'):
  params += torch.prod(torch.LongTensor(list(module.weight.size())))
  if module.weight.requires_grad:
   summary[m_key]['trainable'] = True
  else:
   summary[m_key]['trainable'] = False
  #if hasattr(module, 'bias'):
  # params += torch.prod(torch.LongTensor(list(module.bias.size())))
 
  summary[m_key]['nb_params'] = params
  
 if not isinstance(module, nn.Sequential) and \
  not isinstance(module, nn.ModuleList) and \
  not (module == model):
  hooks.append(module.register_forward_hook(hook))
 
 # check if there are multiple inputs to the network
 if isinstance(input_size[0], (list, tuple)):
 x = [Variable(torch.rand(1,*in_size)) for in_size in input_size]
 else:
 x = Variable(torch.rand(1,*input_size))
 
 # create properties
 summary = OrderedDict()
 hooks = []
 # register hook
 model.apply(register_hook)
 # make a forward pass
 model(x)
 # remove these hooks
 for h in hooks:
 h.remove()
 
 return summary
 
crnn = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1,32,128],crnn)
print json.dumps(x)

以pytorch版CRNN为例,输出shape如下

{
"Conv2d-1": {
"input_shape": [1, 1, 32, 128],
"output_shape": [1, 64, 32, 128],
"trainable": true,
"nb_params": 576
},
"ReLU-2": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 32, 128],
"nb_params": 0
},
"MaxPool2d-3": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 16, 64],
"nb_params": 0
},
"Conv2d-4": {
"input_shape": [1, 64, 16, 64],
"output_shape": [1, 128, 16, 64],
"trainable": true,
"nb_params": 73728
},
"ReLU-5": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 16, 64],
"nb_params": 0
},
"MaxPool2d-6": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 8, 32],
"nb_params": 0
},
"Conv2d-7": {
"input_shape": [1, 128, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 294912
},
"BatchNorm2d-8": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 256
},
"ReLU-9": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"Conv2d-10": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 589824
},
"ReLU-11": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"MaxPool2d-12": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 4, 33],
"nb_params": 0
},
"Conv2d-13": {
"input_shape": [1, 256, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 1179648
},
"BatchNorm2d-14": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-15": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"Conv2d-16": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 2359296
},
"ReLU-17": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"MaxPool2d-18": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 2, 34],
"nb_params": 0
},
"Conv2d-19": {
"input_shape": [1, 512, 2, 34],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 1048576
},
"BatchNorm2d-20": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-21": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"nb_params": 0
},
"LSTM-22": {
"input_shape": [33, 1, 512],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-23": {
"input_shape": [33, 512],
"output_shape": [33, 256],
"trainable": true,
"nb_params": 131072
},
"BidirectionalLSTM-24": {
"input_shape": [33, 1, 512],
"output_shape": [33, 1, 256],
"nb_params": 0
},
"LSTM-25": {
"input_shape": [33, 1, 256],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-26": {
"input_shape": [33, 512],
"output_shape": [33, 3755],
"trainable": true,
"nb_params": 1922560
},
"BidirectionalLSTM-27": {
"input_shape": [33, 1, 256],
"output_shape": [33, 1, 3755],
"nb_params": 0
}
}

以上这篇pytorch中获取模型input/output shape实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用PIL库实现验证码图片的方法
Mar 11 Python
python爬虫入门教程--正则表达式完全指南(五)
May 25 Python
Python的语言类型(详解)
Jun 24 Python
R语言 vs Python对比:数据分析哪家强?
Nov 17 Python
python引入不同文件夹下的自定义模块方法
Oct 27 Python
Python设计模式之模板方法模式实例详解
Jan 17 Python
分析经典Python开发工程师面试题
Apr 08 Python
linux下python中文乱码解决方案详解
Aug 28 Python
基于python实现文件加密功能
Jan 06 Python
Python读取Excel一列并计算所有对象出现次数的方法
Sep 04 Python
Python求区间正整数内所有素数之和的方法实例
Oct 13 Python
学会Python数据可视化必须尝试这7个库
Jun 16 Python
Python读取csv文件实例解析
Dec 30 #Python
Pytorch Tensor的统计属性实例讲解
Dec 30 #Python
PyTorch中permute的用法详解
Dec 30 #Python
python实现多进程按序号批量修改文件名的方法示例
Dec 30 #Python
Pytorch Tensor基本数学运算详解
Dec 30 #Python
python垃圾回收机制(GC)原理解析
Dec 30 #Python
利用Python代码实现一键抠背景功能
Dec 29 #Python
You might like
回首过去10年中最搞笑的10部动漫,哪一部让你节操尽碎?
2020/03/03 日漫
php导入导出excel实例
2013/10/25 PHP
10个超级有用值得收藏的PHP代码片段
2015/01/22 PHP
PHP实现权限管理功能示例
2017/09/22 PHP
如何实现浏览器上的右键菜单
2006/07/10 Javascript
js函数中onmousedown和onclick的区别和联系探讨
2013/05/19 Javascript
javascript中数组的concat()方法使用介绍
2013/12/18 Javascript
jQuery对Select的操作大集合(收藏)
2013/12/28 Javascript
javascript面向对象之共享成员属性与方法及prototype关键字用法
2015/01/13 Javascript
Windows系统中安装nodejs图文教程
2015/02/28 NodeJs
详解vue-Resource(与后端数据交互)
2017/01/16 Javascript
JS实现多级菜单中当前菜单不随页面跳转样式而发生变化
2017/05/30 Javascript
JS实现前端页面的搜索功能
2018/06/12 Javascript
微信小程序支付PHP代码
2018/08/23 Javascript
手动下载Chrome并解决puppeteer无法使用问题
2018/11/12 Javascript
移动端滑动切换组件封装 vue-swiper-router实例详解
2018/11/25 Javascript
vue-better-scroll 的使用实例代码详解
2018/12/03 Javascript
js中call()和apply()改变指针问题的讲解
2019/01/17 Javascript
Vue实现商品分类菜单数量提示功能
2019/07/26 Javascript
uni-app 组件里面获取元素宽高的实现
2019/12/27 Javascript
详解VUE中的插值( Interpolation)语法
2020/10/18 Javascript
关于Vue中$refs的探索浅析
2020/11/05 Javascript
python中的编码知识整理汇总
2016/01/26 Python
Python用Bottle轻量级框架进行Web开发
2016/06/08 Python
pandas 透视表中文字段排序方法
2018/11/16 Python
Python学习笔记之字符串和字符串方法实例详解
2019/08/22 Python
使用PyCharm进行远程开发和调试的实现
2019/11/04 Python
Python 线性回归分析以及评价指标详解
2020/04/02 Python
如何以Winsows Service方式运行JupyterLab
2020/08/30 Python
GUESS德国官网:美国牛仔服装品牌
2017/02/14 全球购物
约瑟夫·特纳男装:Joseph Turner
2017/10/10 全球购物
JSP和Servlet有哪些相同点和不同点,他们之间的联系是什么?
2015/10/22 面试题
中秋晚会活动方案
2014/08/31 职场文书
党员对照检查材料思想汇报(党的群众路线)
2014/09/24 职场文书
2014党员整改措施思想汇报
2014/10/07 职场文书
Win11怎么把合并的任务栏分开 Win11任务栏合并分开教程
2022/04/06 数码科技