pytorch中获取模型input/output shape实例


Posted in Python onDecember 30, 2019

Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见

https://github.com/pytorch/pytorch/pull/3043

以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。

例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了。

该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape

#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json
 
 
def get_output_size(summary_dict, output):
 if isinstance(output, tuple):
 for i in xrange(len(output)):
  summary_dict[i] = OrderedDict()
  summary_dict[i] = get_output_size(summary_dict[i],output[i])
 else:
 summary_dict['output_shape'] = list(output.size())
 return summary_dict
 
def summary(input_size, model):
 def register_hook(module):
 def hook(module, input, output):
  class_name = str(module.__class__).split('.')[-1].split("'")[0]
  module_idx = len(summary)
 
  m_key = '%s-%i' % (class_name, module_idx+1)
  summary[m_key] = OrderedDict()
  summary[m_key]['input_shape'] = list(input[0].size())
  summary[m_key] = get_output_size(summary[m_key], output)
 
  params = 0
  if hasattr(module, 'weight'):
  params += torch.prod(torch.LongTensor(list(module.weight.size())))
  if module.weight.requires_grad:
   summary[m_key]['trainable'] = True
  else:
   summary[m_key]['trainable'] = False
  #if hasattr(module, 'bias'):
  # params += torch.prod(torch.LongTensor(list(module.bias.size())))
 
  summary[m_key]['nb_params'] = params
  
 if not isinstance(module, nn.Sequential) and \
  not isinstance(module, nn.ModuleList) and \
  not (module == model):
  hooks.append(module.register_forward_hook(hook))
 
 # check if there are multiple inputs to the network
 if isinstance(input_size[0], (list, tuple)):
 x = [Variable(torch.rand(1,*in_size)) for in_size in input_size]
 else:
 x = Variable(torch.rand(1,*input_size))
 
 # create properties
 summary = OrderedDict()
 hooks = []
 # register hook
 model.apply(register_hook)
 # make a forward pass
 model(x)
 # remove these hooks
 for h in hooks:
 h.remove()
 
 return summary
 
crnn = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1,32,128],crnn)
print json.dumps(x)

以pytorch版CRNN为例,输出shape如下

{
"Conv2d-1": {
"input_shape": [1, 1, 32, 128],
"output_shape": [1, 64, 32, 128],
"trainable": true,
"nb_params": 576
},
"ReLU-2": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 32, 128],
"nb_params": 0
},
"MaxPool2d-3": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 16, 64],
"nb_params": 0
},
"Conv2d-4": {
"input_shape": [1, 64, 16, 64],
"output_shape": [1, 128, 16, 64],
"trainable": true,
"nb_params": 73728
},
"ReLU-5": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 16, 64],
"nb_params": 0
},
"MaxPool2d-6": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 8, 32],
"nb_params": 0
},
"Conv2d-7": {
"input_shape": [1, 128, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 294912
},
"BatchNorm2d-8": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 256
},
"ReLU-9": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"Conv2d-10": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 589824
},
"ReLU-11": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"MaxPool2d-12": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 4, 33],
"nb_params": 0
},
"Conv2d-13": {
"input_shape": [1, 256, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 1179648
},
"BatchNorm2d-14": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-15": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"Conv2d-16": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 2359296
},
"ReLU-17": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"MaxPool2d-18": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 2, 34],
"nb_params": 0
},
"Conv2d-19": {
"input_shape": [1, 512, 2, 34],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 1048576
},
"BatchNorm2d-20": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-21": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"nb_params": 0
},
"LSTM-22": {
"input_shape": [33, 1, 512],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-23": {
"input_shape": [33, 512],
"output_shape": [33, 256],
"trainable": true,
"nb_params": 131072
},
"BidirectionalLSTM-24": {
"input_shape": [33, 1, 512],
"output_shape": [33, 1, 256],
"nb_params": 0
},
"LSTM-25": {
"input_shape": [33, 1, 256],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-26": {
"input_shape": [33, 512],
"output_shape": [33, 3755],
"trainable": true,
"nb_params": 1922560
},
"BidirectionalLSTM-27": {
"input_shape": [33, 1, 256],
"output_shape": [33, 1, 3755],
"nb_params": 0
}
}

以上这篇pytorch中获取模型input/output shape实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python处理文本文件中控制字符的方法
Feb 07 Python
Python中is与==判断的区别
Mar 28 Python
Python3实现爬取指定百度贴吧页面并保存页面数据生成本地文档的方法
Apr 22 Python
django rest framework 实现用户登录认证详解
Jul 29 Python
scikit-learn线性回归,多元回归,多项式回归的实现
Aug 29 Python
Pytorch实现神经网络的分类方式
Jan 08 Python
ITK 实现多张图像转成单个nii.gz或mha文件案例
Jul 01 Python
详解Python中的编码问题(encoding与decode、str与bytes)
Sep 30 Python
Python暴力破解Mysql数据的示例
Nov 09 Python
基于Pytorch版yolov5的滑块验证码破解思路详解
Feb 25 Python
如何使用Python提取Chrome浏览器保存的密码
Jun 09 Python
python解析照片拍摄时间进行图片整理
Jul 23 Python
Python读取csv文件实例解析
Dec 30 #Python
Pytorch Tensor的统计属性实例讲解
Dec 30 #Python
PyTorch中permute的用法详解
Dec 30 #Python
python实现多进程按序号批量修改文件名的方法示例
Dec 30 #Python
Pytorch Tensor基本数学运算详解
Dec 30 #Python
python垃圾回收机制(GC)原理解析
Dec 30 #Python
利用Python代码实现一键抠背景功能
Dec 29 #Python
You might like
PHP设计模式之装饰者模式
2012/02/29 PHP
PHP封装的一个支持HTML、JS、PHP重定向的多功能跳转函数
2014/06/19 PHP
PHP图片处理之图片背景、画布操作
2014/11/19 PHP
PHP实现的线索二叉树及二叉树遍历方法详解
2016/04/25 PHP
PHP常用文件操作函数和简单实例分析
2016/06/03 PHP
php使用自带dom扩展进行元素匹配的原理解析
2020/05/29 PHP
让你的网站可编辑的实现js代码
2009/10/19 Javascript
js中判断文本框是否为空的两种方法
2011/07/31 Javascript
toggle()隐藏问题的解决方法
2014/02/17 Javascript
Ajax局部更新导致JS事件重复触发问题的解决方法
2014/10/14 Javascript
继续学习javascript闭包
2015/12/03 Javascript
基于javascript实现动态显示当前系统时间
2016/01/28 Javascript
angularjs 中$apply,$digest,$watch详解
2016/10/13 Javascript
Node错误处理笔记之挖坑系列教程
2018/06/05 Javascript
详解关于JSON.parse()和JSON.stringify()的性能小测试
2019/03/14 Javascript
微信小程序云开发获取文件夹下所有文件(推荐)
2019/11/14 Javascript
JS XMLHttpRequest原理与使用方法深入详解
2020/04/30 Javascript
详解js中的原型,原型对象,原型链
2020/07/16 Javascript
Element-ui树形控件el-tree自定义增删改和局部刷新及懒加载操作
2020/08/31 Javascript
浏览器JavaScript调试功能无法使用解决方案
2020/09/18 Javascript
python使用7z解压软件备份文件脚本分享
2014/02/21 Python
Python如何判断数独是否合法
2016/09/08 Python
Pandas 对Dataframe结构排序的实现方法
2018/04/10 Python
python通过tcp发送xml报文的方法
2018/12/28 Python
python print输出延时,让其立刻输出的方法
2019/01/07 Python
pycharm访问mysql数据库的方法步骤
2019/06/18 Python
python读出当前时间精度到秒的代码
2019/07/05 Python
Python中顺序表原理与实现方法详解
2019/12/03 Python
利用pandas将非数值数据转换成数值的方式
2019/12/18 Python
Python 如何定义匿名或内联函数
2020/08/01 Python
Python 测试框架unittest和pytest的优劣
2020/09/26 Python
美国婚礼和派对礼品网站:Kate Aspen(新娘送礼会、迎婴派对)
2018/03/28 全球购物
爱护花草树木的标语
2014/06/11 职场文书
建筑工地标语
2014/06/18 职场文书
学校总务处领导干部个人对照检查材料思想汇报
2014/10/06 职场文书
公安局班子个人对照检查材料思想汇报
2014/10/09 职场文书