pytorch中获取模型input/output shape实例


Posted in Python onDecember 30, 2019

Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见

https://github.com/pytorch/pytorch/pull/3043

以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。

例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了。

该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape

#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json
 
 
def get_output_size(summary_dict, output):
 if isinstance(output, tuple):
 for i in xrange(len(output)):
  summary_dict[i] = OrderedDict()
  summary_dict[i] = get_output_size(summary_dict[i],output[i])
 else:
 summary_dict['output_shape'] = list(output.size())
 return summary_dict
 
def summary(input_size, model):
 def register_hook(module):
 def hook(module, input, output):
  class_name = str(module.__class__).split('.')[-1].split("'")[0]
  module_idx = len(summary)
 
  m_key = '%s-%i' % (class_name, module_idx+1)
  summary[m_key] = OrderedDict()
  summary[m_key]['input_shape'] = list(input[0].size())
  summary[m_key] = get_output_size(summary[m_key], output)
 
  params = 0
  if hasattr(module, 'weight'):
  params += torch.prod(torch.LongTensor(list(module.weight.size())))
  if module.weight.requires_grad:
   summary[m_key]['trainable'] = True
  else:
   summary[m_key]['trainable'] = False
  #if hasattr(module, 'bias'):
  # params += torch.prod(torch.LongTensor(list(module.bias.size())))
 
  summary[m_key]['nb_params'] = params
  
 if not isinstance(module, nn.Sequential) and \
  not isinstance(module, nn.ModuleList) and \
  not (module == model):
  hooks.append(module.register_forward_hook(hook))
 
 # check if there are multiple inputs to the network
 if isinstance(input_size[0], (list, tuple)):
 x = [Variable(torch.rand(1,*in_size)) for in_size in input_size]
 else:
 x = Variable(torch.rand(1,*input_size))
 
 # create properties
 summary = OrderedDict()
 hooks = []
 # register hook
 model.apply(register_hook)
 # make a forward pass
 model(x)
 # remove these hooks
 for h in hooks:
 h.remove()
 
 return summary
 
crnn = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1,32,128],crnn)
print json.dumps(x)

以pytorch版CRNN为例,输出shape如下

{
"Conv2d-1": {
"input_shape": [1, 1, 32, 128],
"output_shape": [1, 64, 32, 128],
"trainable": true,
"nb_params": 576
},
"ReLU-2": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 32, 128],
"nb_params": 0
},
"MaxPool2d-3": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 16, 64],
"nb_params": 0
},
"Conv2d-4": {
"input_shape": [1, 64, 16, 64],
"output_shape": [1, 128, 16, 64],
"trainable": true,
"nb_params": 73728
},
"ReLU-5": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 16, 64],
"nb_params": 0
},
"MaxPool2d-6": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 8, 32],
"nb_params": 0
},
"Conv2d-7": {
"input_shape": [1, 128, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 294912
},
"BatchNorm2d-8": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 256
},
"ReLU-9": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"Conv2d-10": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 589824
},
"ReLU-11": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"MaxPool2d-12": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 4, 33],
"nb_params": 0
},
"Conv2d-13": {
"input_shape": [1, 256, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 1179648
},
"BatchNorm2d-14": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-15": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"Conv2d-16": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 2359296
},
"ReLU-17": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"MaxPool2d-18": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 2, 34],
"nb_params": 0
},
"Conv2d-19": {
"input_shape": [1, 512, 2, 34],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 1048576
},
"BatchNorm2d-20": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-21": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"nb_params": 0
},
"LSTM-22": {
"input_shape": [33, 1, 512],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-23": {
"input_shape": [33, 512],
"output_shape": [33, 256],
"trainable": true,
"nb_params": 131072
},
"BidirectionalLSTM-24": {
"input_shape": [33, 1, 512],
"output_shape": [33, 1, 256],
"nb_params": 0
},
"LSTM-25": {
"input_shape": [33, 1, 256],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-26": {
"input_shape": [33, 512],
"output_shape": [33, 3755],
"trainable": true,
"nb_params": 1922560
},
"BidirectionalLSTM-27": {
"input_shape": [33, 1, 256],
"output_shape": [33, 1, 3755],
"nb_params": 0
}
}

以上这篇pytorch中获取模型input/output shape实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python遍历目录的4种方法实例介绍
Apr 13 Python
Python数据类型详解(一)字符串
May 08 Python
Python的re模块正则表达式操作
May 25 Python
python flask实现分页效果
Jun 27 Python
Python cookbook(数据结构与算法)从字典中提取子集的方法示例
Mar 22 Python
Python3 jupyter notebook 服务器搭建过程
Nov 30 Python
Python队列、进程间通信、线程案例
Oct 25 Python
如何基于Python批量下载音乐
Nov 11 Python
Python如何读取文件中图片格式
Jan 13 Python
Django中的DateTimeField和DateField实现
Feb 24 Python
只用Python就可以制作的简单词云
Jun 07 Python
Python爬虫框架之Scrapy中Spider的用法
Jun 28 Python
Python读取csv文件实例解析
Dec 30 #Python
Pytorch Tensor的统计属性实例讲解
Dec 30 #Python
PyTorch中permute的用法详解
Dec 30 #Python
python实现多进程按序号批量修改文件名的方法示例
Dec 30 #Python
Pytorch Tensor基本数学运算详解
Dec 30 #Python
python垃圾回收机制(GC)原理解析
Dec 30 #Python
利用Python代码实现一键抠背景功能
Dec 29 #Python
You might like
2款PHP无限级分类实例代码
2015/11/11 PHP
PHP文件与目录操作示例
2016/12/24 PHP
PHP+Ajax简单get验证操作示例
2019/03/02 PHP
JavaScript 继承详解(二)
2009/07/13 Javascript
javascript实现的基于金山词霸网络翻译的代码
2010/01/15 Javascript
AlertBox 弹出层信息提示框效果实现步骤
2010/10/11 Javascript
突发奇想的一个jquery插件
2010/11/19 Javascript
js中的onchange和onpropertychange (onchange无效的解决方法)
2014/03/08 Javascript
jquery插件之定时查询待处理任务数量
2014/05/01 Javascript
Jquery图片延迟加载插件jquery.lazyload.js的使用方法
2014/05/21 Javascript
jQuery实现单击按钮遮罩弹出对话框效果(2)
2017/02/20 Javascript
React-Native 组件之 Modal的使用详解
2017/08/08 Javascript
windows下更新npm和node的方法
2017/11/30 Javascript
Vue中使用的EventBus有生命周期
2018/07/12 Javascript
基于Angular中ng-controller父子级嵌套的相关属性详解
2018/10/08 Javascript
微信公众平台 发送模板消息(Java接口开发)
2019/04/17 Javascript
Vue+Element实现表格编辑、删除、以及新增行的最优方法
2019/05/28 Javascript
[06:24]DOTA2亚洲邀请赛小组赛第三日 TOP10精彩集锦
2015/02/01 DOTA
[01:35]辉夜杯战队访谈宣传片—LGD
2015/12/25 DOTA
在Python操作时间和日期之asctime()方法的使用
2015/05/22 Python
Python os模块学习笔记
2015/06/21 Python
Python实现更改图片尺寸大小的方法(基于Pillow包)
2016/09/19 Python
Python+OpenCV图片局部区域像素值处理改进版详解
2019/01/23 Python
python定时按日期备份MySQL数据并压缩
2019/04/19 Python
Python super()函数使用及多重继承
2020/05/06 Python
Asics日本官网:鬼冢八喜郎创立的跑鞋运动品牌
2017/10/18 全球购物
eVitamins日本:在线购买折扣维生素、补品和草药
2019/04/04 全球购物
美国牙科折扣计划:DentalPlans.com
2019/08/26 全球购物
英国领先的男装设计师服装独立零售商:Repertoire Fashion
2020/10/19 全球购物
如何安装ruby on rails
2014/02/09 面试题
企业党的群众路线教育实践活动领导班子对照检查材料
2014/09/25 职场文书
2015年法制宣传月活动总结
2015/03/26 职场文书
盗窃案辩护词
2015/05/21 职场文书
高一数学教学反思
2016/02/18 职场文书
CSS实现多个元素在盒子内两端对齐效果
2021/03/30 HTML / CSS
血轮眼轮回眼特效 html+css
2021/03/31 HTML / CSS