pytorch中获取模型input/output shape实例


Posted in Python onDecember 30, 2019

Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见

https://github.com/pytorch/pytorch/pull/3043

以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。

例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了。

该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape

#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json
 
 
def get_output_size(summary_dict, output):
 if isinstance(output, tuple):
 for i in xrange(len(output)):
  summary_dict[i] = OrderedDict()
  summary_dict[i] = get_output_size(summary_dict[i],output[i])
 else:
 summary_dict['output_shape'] = list(output.size())
 return summary_dict
 
def summary(input_size, model):
 def register_hook(module):
 def hook(module, input, output):
  class_name = str(module.__class__).split('.')[-1].split("'")[0]
  module_idx = len(summary)
 
  m_key = '%s-%i' % (class_name, module_idx+1)
  summary[m_key] = OrderedDict()
  summary[m_key]['input_shape'] = list(input[0].size())
  summary[m_key] = get_output_size(summary[m_key], output)
 
  params = 0
  if hasattr(module, 'weight'):
  params += torch.prod(torch.LongTensor(list(module.weight.size())))
  if module.weight.requires_grad:
   summary[m_key]['trainable'] = True
  else:
   summary[m_key]['trainable'] = False
  #if hasattr(module, 'bias'):
  # params += torch.prod(torch.LongTensor(list(module.bias.size())))
 
  summary[m_key]['nb_params'] = params
  
 if not isinstance(module, nn.Sequential) and \
  not isinstance(module, nn.ModuleList) and \
  not (module == model):
  hooks.append(module.register_forward_hook(hook))
 
 # check if there are multiple inputs to the network
 if isinstance(input_size[0], (list, tuple)):
 x = [Variable(torch.rand(1,*in_size)) for in_size in input_size]
 else:
 x = Variable(torch.rand(1,*input_size))
 
 # create properties
 summary = OrderedDict()
 hooks = []
 # register hook
 model.apply(register_hook)
 # make a forward pass
 model(x)
 # remove these hooks
 for h in hooks:
 h.remove()
 
 return summary
 
crnn = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1,32,128],crnn)
print json.dumps(x)

以pytorch版CRNN为例,输出shape如下

{
"Conv2d-1": {
"input_shape": [1, 1, 32, 128],
"output_shape": [1, 64, 32, 128],
"trainable": true,
"nb_params": 576
},
"ReLU-2": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 32, 128],
"nb_params": 0
},
"MaxPool2d-3": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 16, 64],
"nb_params": 0
},
"Conv2d-4": {
"input_shape": [1, 64, 16, 64],
"output_shape": [1, 128, 16, 64],
"trainable": true,
"nb_params": 73728
},
"ReLU-5": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 16, 64],
"nb_params": 0
},
"MaxPool2d-6": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 8, 32],
"nb_params": 0
},
"Conv2d-7": {
"input_shape": [1, 128, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 294912
},
"BatchNorm2d-8": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 256
},
"ReLU-9": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"Conv2d-10": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 589824
},
"ReLU-11": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"MaxPool2d-12": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 4, 33],
"nb_params": 0
},
"Conv2d-13": {
"input_shape": [1, 256, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 1179648
},
"BatchNorm2d-14": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-15": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"Conv2d-16": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 2359296
},
"ReLU-17": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"MaxPool2d-18": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 2, 34],
"nb_params": 0
},
"Conv2d-19": {
"input_shape": [1, 512, 2, 34],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 1048576
},
"BatchNorm2d-20": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-21": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"nb_params": 0
},
"LSTM-22": {
"input_shape": [33, 1, 512],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-23": {
"input_shape": [33, 512],
"output_shape": [33, 256],
"trainable": true,
"nb_params": 131072
},
"BidirectionalLSTM-24": {
"input_shape": [33, 1, 512],
"output_shape": [33, 1, 256],
"nb_params": 0
},
"LSTM-25": {
"input_shape": [33, 1, 256],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-26": {
"input_shape": [33, 512],
"output_shape": [33, 3755],
"trainable": true,
"nb_params": 1922560
},
"BidirectionalLSTM-27": {
"input_shape": [33, 1, 256],
"output_shape": [33, 1, 3755],
"nb_params": 0
}
}

以上这篇pytorch中获取模型input/output shape实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python最基本的数据类型以及对元组的介绍
Apr 14 Python
Python算法之求n个节点不同二叉树个数
Oct 27 Python
Scrapy抓取京东商品、豆瓣电影及代码分享
Nov 23 Python
基于DataFrame改变列类型的方法
Jul 25 Python
Python实现的排列组合、破解密码算法示例
Apr 12 Python
python SQLAlchemy的Mapping与Declarative详解
Jul 04 Python
Python的几种主动结束程序方式
Nov 22 Python
python随机生成大小写字母数字混合密码(仅20行代码)
Feb 01 Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 Python
python实现三种随机请求头方式
Jan 05 Python
基于PyQT5制作一个桌面摸鱼工具
Feb 15 Python
Python实现双向链表
May 25 Python
Python读取csv文件实例解析
Dec 30 #Python
Pytorch Tensor的统计属性实例讲解
Dec 30 #Python
PyTorch中permute的用法详解
Dec 30 #Python
python实现多进程按序号批量修改文件名的方法示例
Dec 30 #Python
Pytorch Tensor基本数学运算详解
Dec 30 #Python
python垃圾回收机制(GC)原理解析
Dec 30 #Python
利用Python代码实现一键抠背景功能
Dec 29 #Python
You might like
关于php循环跳出的问题
2013/07/01 PHP
本地机apache配置基于域名的虚拟主机详解
2013/08/10 PHP
php微信高级接口群发 多客服
2016/06/23 PHP
Zend Framework前端控制器用法示例
2016/12/11 PHP
JavaScript 给汉字排序实例代码
2008/06/28 Javascript
Jquery调用webService远程访问出错的解决方法
2010/05/21 Javascript
深入理解JavaScript高级之词法作用域和作用域链
2013/12/10 Javascript
jquery 绑定回车动作扑捉回车键触发的事件
2014/03/26 Javascript
JavaScript仿支付宝密码输入框
2015/12/29 Javascript
js实现简单的计算器功能
2017/01/16 Javascript
微信小程序 向左滑动删除功能的实现
2017/03/10 Javascript
js中less常用的方法小结
2017/08/09 Javascript
JavaScript判断变量名是否存在数组中的实例
2017/12/28 Javascript
jQuery实现数字自动增加或者减少的动画效果示例
2018/12/11 jQuery
React+Redux实现简单的待办事项列表ToDoList
2019/09/29 Javascript
[02:10]DOTA2亚洲邀请赛 EG战队出场宣传片
2015/02/07 DOTA
python求众数问题实例
2014/09/26 Python
Python3 能振兴 Python的原因分析
2014/11/28 Python
python实现ping的方法
2015/07/06 Python
Python闭包函数定义与用法分析
2018/07/20 Python
详解用Python实现自动化监控远程服务器
2019/05/18 Python
vscode 配置 python3开发环境的方法
2019/09/19 Python
tensorflow实现tensor中满足某一条件的数值取出组成新的tensor
2020/01/04 Python
详解如何在PyCharm控制台中输出彩色文字和背景
2020/08/17 Python
实现CSS3中的border-radius(边框圆角)示例代码
2013/07/19 HTML / CSS
印尼最大的网上书店:Gramedia.com
2018/09/13 全球购物
与世界上最好的跑步专业品牌合作:Fleet Feet
2019/03/22 全球购物
工商管理专业应届生求职信
2013/11/04 职场文书
安全施工标语
2014/06/07 职场文书
2014国庆节国旗下演讲稿(精选版)
2014/09/26 职场文书
六一儿童节主持开场白
2015/05/28 职场文书
《有余数的除法》教学反思
2016/02/22 职场文书
导游词之新疆-喀纳斯
2019/10/10 职场文书
Nginx中break与last的区别详析
2021/03/31 Servers
Java新手教程之ArrayList的基本使用
2021/06/20 Java/Android
Python面试不修改数组找出重复的数字
2022/05/20 Python