pytorch查看网络参数显存占用量等操作


Posted in Python onMay 12, 2021

1.使用torchstat

pip install torchstat 

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数。我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数。

2.使用torchsummary

pip install torchsummary
 
from torchsummary import summary
summary(model.cuda(),input_size=(3,32,32),batch_size=-1)

使用该函数直接对参数进行提示,可以发现直接有显式输入batch_size的地方,我自己的感觉好像该函数更好一些。但是!!!不知道为什么,该函数在我的机器上一直报错!!!

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Update:经过论坛咨询,报错的原因找到了,只需要把

pip install torchsummary

修改为

pip install torch-summary

补充:Pytorch查看模型参数并计算模型参数量与可训练参数量

查看模型参数(以AlexNet为例)

import torch
import torch.nn as nn
import torchvision
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000):
        super(AlexNet,self).__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256*6*6,out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
    def forward(self,x):
        x = self.feature_extraction(x)
        x = x.view(x.size(0),256*6*6)
        x = self.classifier(x)
        return x
if __name__ =='__main__':
    # model = torchvision.models.AlexNet()
    model = AlexNet()
    
    # 打印模型参数
    #for param in model.parameters():
        #print(param)
    
    #打印模型名称与shape
    for name,parameters in model.named_parameters():
        print(name,':',parameters.size())
feature_extraction.0.weight : torch.Size([96, 3, 11, 11])
feature_extraction.3.weight : torch.Size([192, 96, 5, 5])
feature_extraction.6.weight : torch.Size([384, 192, 3, 3])
feature_extraction.8.weight : torch.Size([256, 384, 3, 3])
feature_extraction.10.weight : torch.Size([256, 256, 3, 3])
classifier.1.weight : torch.Size([4096, 9216])
classifier.1.bias : torch.Size([4096])
classifier.4.weight : torch.Size([4096, 4096])
classifier.4.bias : torch.Size([4096])
classifier.6.weight : torch.Size([1000, 4096])
classifier.6.bias : torch.Size([1000])

计算参数量与可训练参数量

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

第三方工具

from torchstat import stat
import torchvision.models as models
model = models.alexnet()
stat(model, (3, 224, 224))

pytorch查看网络参数显存占用量等操作

from torchvision.models import alexnet
import torch
from thop import profile
model = alexnet()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params)

pytorch查看网络参数显存占用量等操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python回调函数用法实例详解
Jul 02 Python
python 安装virtualenv和virtualenvwrapper的方法
Jan 13 Python
Python中列表list以及list与数组array的相互转换实现方法
Sep 22 Python
详解Python 解压缩文件
Apr 09 Python
python pygame实现方向键控制小球
May 17 Python
Python3enumrate和range对比及示例详解
Jul 13 Python
python3.6生成器yield用法实例分析
Aug 23 Python
pycharm无法导入本地模块的解决方式
Feb 12 Python
python求最大公约数和最小公倍数的简单方法
Feb 13 Python
Python中remove漏删和索引越界问题的解决
Mar 18 Python
python topk()函数求最大和最小值实例
Apr 02 Python
Matplotlib中rcParams使用方法
Jan 05 Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
pytorch 中autograd.grad()函数的用法说明
python3实现无权最短路径的方法
Python入门之基础语法详解
May 11 #Python
如何利用Matlab制作一款真正的拼图小游戏
You might like
php中使用Curl、socket、file_get_contents三种方法POST提交数据
2011/08/12 PHP
php判断文件夹是否存在不存在则创建
2015/04/09 PHP
PHP IDE PHPStorm配置支持友好Laravel代码提示方法
2015/05/12 PHP
PHP自动识别当前使用移动终端
2018/05/21 PHP
将jQuery应用于login页面的问题及解决
2009/10/17 Javascript
使用JavaScript库还是自己写代码?
2010/01/28 Javascript
浅析XMLHttpRequest的缓存问题
2013/12/13 Javascript
node.js中的emitter.on方法使用说明
2014/12/10 Javascript
JavaScript实现表格点击排序的方法
2015/05/11 Javascript
深入学习JavaScript的AngularJS框架中指令的使用方法
2016/03/05 Javascript
用jQuery向div中添加Html文本内容的简单实现
2016/07/13 Javascript
JS实现保留n位小数的四舍五入问题示例
2016/08/03 Javascript
轻松学习Javascript闭包
2017/03/01 Javascript
JavaScript设计模式之策略模式详解
2017/06/09 Javascript
浅谈Node.js ORM框架Sequlize之表间关系
2017/07/24 Javascript
AngularJS实现的2048小游戏功能【附源码下载】
2018/01/03 Javascript
Phaser.js实现简单的跑酷游戏附源码下载
2018/10/26 Javascript
vue.js页面加载执行created,mounted的先后顺序说明
2020/11/07 Javascript
原生JS实现拖拽功能
2020/12/16 Javascript
[51:22]Fnatic vs IG 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/18 DOTA
[00:15]天涯墨客终极技能展示
2018/08/25 DOTA
python机器学习实战之最近邻kNN分类器
2017/12/20 Python
Python3实现的判断环形链表算法示例
2019/03/07 Python
详解Python解决抓取内容乱码问题(decode和encode解码)
2019/03/29 Python
Python文件读写w+和r+区别解析
2020/03/26 Python
详解用Pytest+Allure生成漂亮的HTML图形化测试报告
2020/03/31 Python
Python接口测试文件上传实例解析
2020/05/22 Python
销售自荐信
2013/10/22 职场文书
建筑人员岗位职责
2013/12/25 职场文书
党的群众路线教育实践活动个人自我剖析材料
2014/10/07 职场文书
2014年电厂个人工作总结
2014/11/27 职场文书
高中生打架检讨书1000字
2015/02/17 职场文书
如何写辞职书
2015/02/26 职场文书
如何书写读后感?(附范文)
2019/07/26 职场文书
带你彻底理解JavaScript中的原型对象
2021/04/14 Javascript
JS 4个超级实用的小技巧 提升开发效率
2021/10/05 Javascript