pytorch查看网络参数显存占用量等操作


Posted in Python onMay 12, 2021

1.使用torchstat

pip install torchstat 

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数。我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数。

2.使用torchsummary

pip install torchsummary
 
from torchsummary import summary
summary(model.cuda(),input_size=(3,32,32),batch_size=-1)

使用该函数直接对参数进行提示,可以发现直接有显式输入batch_size的地方,我自己的感觉好像该函数更好一些。但是!!!不知道为什么,该函数在我的机器上一直报错!!!

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Update:经过论坛咨询,报错的原因找到了,只需要把

pip install torchsummary

修改为

pip install torch-summary

补充:Pytorch查看模型参数并计算模型参数量与可训练参数量

查看模型参数(以AlexNet为例)

import torch
import torch.nn as nn
import torchvision
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000):
        super(AlexNet,self).__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256*6*6,out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
    def forward(self,x):
        x = self.feature_extraction(x)
        x = x.view(x.size(0),256*6*6)
        x = self.classifier(x)
        return x
if __name__ =='__main__':
    # model = torchvision.models.AlexNet()
    model = AlexNet()
    
    # 打印模型参数
    #for param in model.parameters():
        #print(param)
    
    #打印模型名称与shape
    for name,parameters in model.named_parameters():
        print(name,':',parameters.size())
feature_extraction.0.weight : torch.Size([96, 3, 11, 11])
feature_extraction.3.weight : torch.Size([192, 96, 5, 5])
feature_extraction.6.weight : torch.Size([384, 192, 3, 3])
feature_extraction.8.weight : torch.Size([256, 384, 3, 3])
feature_extraction.10.weight : torch.Size([256, 256, 3, 3])
classifier.1.weight : torch.Size([4096, 9216])
classifier.1.bias : torch.Size([4096])
classifier.4.weight : torch.Size([4096, 4096])
classifier.4.bias : torch.Size([4096])
classifier.6.weight : torch.Size([1000, 4096])
classifier.6.bias : torch.Size([1000])

计算参数量与可训练参数量

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

第三方工具

from torchstat import stat
import torchvision.models as models
model = models.alexnet()
stat(model, (3, 224, 224))

pytorch查看网络参数显存占用量等操作

from torchvision.models import alexnet
import torch
from thop import profile
model = alexnet()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params)

pytorch查看网络参数显存占用量等操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python2.x和3.x下maketrans与translate函数使用上的不同
Apr 13 Python
Python的pycurl包用法简介
Nov 13 Python
Python中元组,列表,字典的区别
May 21 Python
python 寻找优化使成本函数最小的最优解的方法
Dec 28 Python
python2.7+selenium2实现淘宝滑块自动认证功能
Feb 24 Python
删除DataFrame中值全为NaN或者包含有NaN的列或行方法
Nov 06 Python
用Python实现大文本文件切割的方法
Jan 12 Python
实例讲解Python中浮点型的基本内容
Feb 11 Python
利用python计算时间差(返回天数)
Sep 07 Python
django中url映射规则和服务端响应顺序的实现
Apr 02 Python
Python的历史与优缺点整理
May 26 Python
分析Python感知线程状态的解决方案之Event与信号量
Jun 16 Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
pytorch 中autograd.grad()函数的用法说明
python3实现无权最短路径的方法
Python入门之基础语法详解
May 11 #Python
如何利用Matlab制作一款真正的拼图小游戏
You might like
ThinkPHP CURD方法之page方法详解
2014/06/18 PHP
PHP实现图片上传并压缩
2015/12/22 PHP
PHP中使用foreach()遍历二维数组的简单实例
2016/06/13 PHP
JavaScript 常见对象类创建代码与优缺点分析
2009/12/07 Javascript
combox改进版 页面原型参考dojo的,比网上jQuery的那些combox功能强,代码更小
2010/04/15 Javascript
MC Dialog js弹出层 完美兼容多浏览器(5.6更新)
2010/05/06 Javascript
jquery 实现上下滚动效果示例代码
2013/08/09 Javascript
js文件Cookie存取值示例代码
2014/02/20 Javascript
javascript删除元素节点removeChild()用法实例
2015/05/26 Javascript
使用JavaScript实现旋转的彩圈特效
2015/06/23 Javascript
jQuery文本框得到与失去焦点动态改变样式效果
2016/09/08 Javascript
AngularJS自定义指令实现面包屑功能完整实例
2017/05/17 Javascript
基于pako.js实现gzip的压缩和解压功能示例
2017/06/13 Javascript
MUI实现上拉加载和下拉刷新效果
2017/06/30 Javascript
jQuery Datatables表头不对齐的解决办法
2017/11/27 jQuery
layui实现数据表格table分页功能(ajax异步)
2019/07/27 Javascript
解决layui动态添加的元素click等事件触发不了的问题
2019/09/20 Javascript
vue基于better-scroll实现左右联动滑动页面
2020/06/30 Javascript
如何基于viewport vm适配移动端页面
2020/11/13 Javascript
jQuery冲突问题解决方法
2021/01/19 jQuery
Python爬虫抓取手机APP的传输数据
2016/01/22 Python
特征脸(Eigenface)理论基础之PCA主成分分析法
2018/03/13 Python
Python3.5 + sklearn利用SVM自动识别字母验证码方法示例
2019/05/10 Python
Python符号计算之实现函数极限的方法
2019/07/15 Python
python批量将excel内容进行翻译写入功能
2019/10/10 Python
Pandas中DataFrame交换列顺序的方法实现
2020/12/14 Python
肯尼亚网上商城:Kilimall
2016/08/20 全球购物
TCP/IP的分层模型
2013/10/27 面试题
营销总经理的岗位职责
2013/12/15 职场文书
医药工作岗位求职信分享
2013/12/31 职场文书
幼儿园教学随笔感言
2014/02/23 职场文书
学习十八大报告感言
2014/02/28 职场文书
民主生活会发言材料
2014/10/20 职场文书
和领导吃饭祝酒词
2015/08/11 职场文书
你会写请假条吗?
2019/06/26 职场文书
Windows10安装Apache2.4的方法步骤
2022/06/25 Servers