pytorch中获取模型input/output shape实例


Posted in Python onDecember 30, 2019

Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见

https://github.com/pytorch/pytorch/pull/3043

以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。

例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了。

该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape

#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json
 
 
def get_output_size(summary_dict, output):
 if isinstance(output, tuple):
 for i in xrange(len(output)):
  summary_dict[i] = OrderedDict()
  summary_dict[i] = get_output_size(summary_dict[i],output[i])
 else:
 summary_dict['output_shape'] = list(output.size())
 return summary_dict
 
def summary(input_size, model):
 def register_hook(module):
 def hook(module, input, output):
  class_name = str(module.__class__).split('.')[-1].split("'")[0]
  module_idx = len(summary)
 
  m_key = '%s-%i' % (class_name, module_idx+1)
  summary[m_key] = OrderedDict()
  summary[m_key]['input_shape'] = list(input[0].size())
  summary[m_key] = get_output_size(summary[m_key], output)
 
  params = 0
  if hasattr(module, 'weight'):
  params += torch.prod(torch.LongTensor(list(module.weight.size())))
  if module.weight.requires_grad:
   summary[m_key]['trainable'] = True
  else:
   summary[m_key]['trainable'] = False
  #if hasattr(module, 'bias'):
  # params += torch.prod(torch.LongTensor(list(module.bias.size())))
 
  summary[m_key]['nb_params'] = params
  
 if not isinstance(module, nn.Sequential) and \
  not isinstance(module, nn.ModuleList) and \
  not (module == model):
  hooks.append(module.register_forward_hook(hook))
 
 # check if there are multiple inputs to the network
 if isinstance(input_size[0], (list, tuple)):
 x = [Variable(torch.rand(1,*in_size)) for in_size in input_size]
 else:
 x = Variable(torch.rand(1,*input_size))
 
 # create properties
 summary = OrderedDict()
 hooks = []
 # register hook
 model.apply(register_hook)
 # make a forward pass
 model(x)
 # remove these hooks
 for h in hooks:
 h.remove()
 
 return summary
 
crnn = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1,32,128],crnn)
print json.dumps(x)

以pytorch版CRNN为例,输出shape如下

{
"Conv2d-1": {
"input_shape": [1, 1, 32, 128],
"output_shape": [1, 64, 32, 128],
"trainable": true,
"nb_params": 576
},
"ReLU-2": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 32, 128],
"nb_params": 0
},
"MaxPool2d-3": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 16, 64],
"nb_params": 0
},
"Conv2d-4": {
"input_shape": [1, 64, 16, 64],
"output_shape": [1, 128, 16, 64],
"trainable": true,
"nb_params": 73728
},
"ReLU-5": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 16, 64],
"nb_params": 0
},
"MaxPool2d-6": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 8, 32],
"nb_params": 0
},
"Conv2d-7": {
"input_shape": [1, 128, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 294912
},
"BatchNorm2d-8": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 256
},
"ReLU-9": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"Conv2d-10": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 589824
},
"ReLU-11": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"MaxPool2d-12": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 4, 33],
"nb_params": 0
},
"Conv2d-13": {
"input_shape": [1, 256, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 1179648
},
"BatchNorm2d-14": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-15": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"Conv2d-16": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 2359296
},
"ReLU-17": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"MaxPool2d-18": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 2, 34],
"nb_params": 0
},
"Conv2d-19": {
"input_shape": [1, 512, 2, 34],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 1048576
},
"BatchNorm2d-20": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-21": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"nb_params": 0
},
"LSTM-22": {
"input_shape": [33, 1, 512],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-23": {
"input_shape": [33, 512],
"output_shape": [33, 256],
"trainable": true,
"nb_params": 131072
},
"BidirectionalLSTM-24": {
"input_shape": [33, 1, 512],
"output_shape": [33, 1, 256],
"nb_params": 0
},
"LSTM-25": {
"input_shape": [33, 1, 256],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-26": {
"input_shape": [33, 512],
"output_shape": [33, 3755],
"trainable": true,
"nb_params": 1922560
},
"BidirectionalLSTM-27": {
"input_shape": [33, 1, 256],
"output_shape": [33, 1, 3755],
"nb_params": 0
}
}

以上这篇pytorch中获取模型input/output shape实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python向日志输出中添加上下文信息
May 24 Python
python读取中文txt文本的方法
Apr 12 Python
python3 字符串/列表/元组(str/list/tuple)相互转换方法及join()函数的使用
Apr 03 Python
详解python中的time和datetime的常用方法
Jul 08 Python
python设置环境变量的作用和实例
Jul 09 Python
python自动发微信监控报警
Sep 06 Python
softmax及python实现过程解析
Sep 30 Python
基于pandas中expand的作用详解
Dec 17 Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 Python
python pandas.DataFrame.loc函数使用详解
Mar 26 Python
python,Java,JavaScript实现indexOf
Sep 09 Python
python 机器学习的标准化、归一化、正则化、离散化和白化
Apr 16 Python
Python读取csv文件实例解析
Dec 30 #Python
Pytorch Tensor的统计属性实例讲解
Dec 30 #Python
PyTorch中permute的用法详解
Dec 30 #Python
python实现多进程按序号批量修改文件名的方法示例
Dec 30 #Python
Pytorch Tensor基本数学运算详解
Dec 30 #Python
python垃圾回收机制(GC)原理解析
Dec 30 #Python
利用Python代码实现一键抠背景功能
Dec 29 #Python
You might like
用PHP连接Oracle数据库
2006/10/09 PHP
建立动态的WML站点(三)
2006/10/09 PHP
is_uploaded_file函数引发的不能上传文件问题
2013/10/29 PHP
php使用exec shell命令注入的方法讲解
2013/11/12 PHP
php Imagick获取图片RGB颜色值
2014/07/28 PHP
php利用云片网实现短信验证码功能的示例代码
2017/11/18 PHP
jquery select选中的一个小问题
2009/10/11 Javascript
js 实现打印网页中定义的部分内容的代码
2010/04/01 Javascript
contains和compareDocumentPosition 方法来确定是否HTML节点间的关系
2011/09/13 Javascript
jQuery 阴影插件代码分享
2012/01/09 Javascript
javascript带回调函数的异步脚本载入方法实例分析
2015/07/02 Javascript
jQuery Uploadify 上传插件出现Http Error 302 错误的解决办法
2015/12/12 Javascript
浅谈jQuery中事情的动态绑定
2017/02/12 Javascript
基于Vue实现拖拽效果
2018/04/27 Javascript
微信小程序用户位置权限的获取方法(拒绝后提醒)
2018/11/15 Javascript
微信小程序 wx:for遍历循环使用实例解析
2019/09/09 Javascript
Python闭包实现计数器的方法
2015/05/05 Python
python输出决策树图形的例子
2019/08/09 Python
解决pytorch DataLoader num_workers出现的问题
2020/01/14 Python
通过实例了解Python异常处理机制底层实现
2020/07/23 Python
多个版本的python共存时使用pip的正确做法
2020/10/26 Python
HTML5单页面手势滑屏切换原理
2016/03/21 HTML / CSS
美国婚礼和派对礼品网站:Kate Aspen(新娘送礼会、迎婴派对)
2018/03/28 全球购物
Koral官方网站:女性时尚运动服
2019/04/10 全球购物
在职党员进社区活动总结
2014/07/05 职场文书
初中学习计划书范文
2014/09/15 职场文书
党员作风建设自查报告
2014/10/23 职场文书
2014年保洁员工作总结
2014/11/19 职场文书
见习期个人总结
2015/03/05 职场文书
学生乘坐校车安全责任书
2015/05/11 职场文书
严以律己专题学习研讨会发言材料
2015/11/09 职场文书
python 实现的截屏工具
2021/05/08 Python
浅谈redis整数集为什么不能降级
2021/07/25 Redis
SQL Server2019数据库备份与还原脚本,数据库可批量备份
2021/11/20 SQL Server
Windows Server 2012配置DNS服务器的方法
2022/04/29 Servers
MySQL外键约束(Foreign Key)案例详解
2022/06/28 MySQL