pytorch查看网络参数显存占用量等操作


Posted in Python onMay 12, 2021

1.使用torchstat

pip install torchstat 

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数。我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数。

2.使用torchsummary

pip install torchsummary
 
from torchsummary import summary
summary(model.cuda(),input_size=(3,32,32),batch_size=-1)

使用该函数直接对参数进行提示,可以发现直接有显式输入batch_size的地方,我自己的感觉好像该函数更好一些。但是!!!不知道为什么,该函数在我的机器上一直报错!!!

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Update:经过论坛咨询,报错的原因找到了,只需要把

pip install torchsummary

修改为

pip install torch-summary

补充:Pytorch查看模型参数并计算模型参数量与可训练参数量

查看模型参数(以AlexNet为例)

import torch
import torch.nn as nn
import torchvision
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000):
        super(AlexNet,self).__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256*6*6,out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
    def forward(self,x):
        x = self.feature_extraction(x)
        x = x.view(x.size(0),256*6*6)
        x = self.classifier(x)
        return x
if __name__ =='__main__':
    # model = torchvision.models.AlexNet()
    model = AlexNet()
    
    # 打印模型参数
    #for param in model.parameters():
        #print(param)
    
    #打印模型名称与shape
    for name,parameters in model.named_parameters():
        print(name,':',parameters.size())
feature_extraction.0.weight : torch.Size([96, 3, 11, 11])
feature_extraction.3.weight : torch.Size([192, 96, 5, 5])
feature_extraction.6.weight : torch.Size([384, 192, 3, 3])
feature_extraction.8.weight : torch.Size([256, 384, 3, 3])
feature_extraction.10.weight : torch.Size([256, 256, 3, 3])
classifier.1.weight : torch.Size([4096, 9216])
classifier.1.bias : torch.Size([4096])
classifier.4.weight : torch.Size([4096, 4096])
classifier.4.bias : torch.Size([4096])
classifier.6.weight : torch.Size([1000, 4096])
classifier.6.bias : torch.Size([1000])

计算参数量与可训练参数量

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

第三方工具

from torchstat import stat
import torchvision.models as models
model = models.alexnet()
stat(model, (3, 224, 224))

pytorch查看网络参数显存占用量等操作

from torchvision.models import alexnet
import torch
from thop import profile
model = alexnet()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params)

pytorch查看网络参数显存占用量等操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python中的作用域规则详解
Jan 30 Python
举例讲解Python面向对象编程中类的继承
Jun 17 Python
python中print的不换行即时输出的快速解决方法
Jul 20 Python
Python代码解决RenderView窗口not found问题
Aug 28 Python
Python实现SSH远程登陆,并执行命令的方法(分享)
May 08 Python
基于python list对象中嵌套元组使用sort时的排序方法
Apr 18 Python
Python3.5 处理文本txt,删除不需要的行方法
Dec 10 Python
Python 运行 shell 获取输出结果的实例
Jan 07 Python
Django日志及中间件模块应用案例
Sep 10 Python
python os.rename实例用法详解
Dec 06 Python
Python快速实现一键抠图功能的全过程
Jun 29 Python
Python jiaba库的使用详解
Nov 23 Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
pytorch 中autograd.grad()函数的用法说明
python3实现无权最短路径的方法
Python入门之基础语法详解
May 11 #Python
如何利用Matlab制作一款真正的拼图小游戏
You might like
通过缓存数据库结果提高PHP性能的原理介绍
2012/09/05 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(三)
2014/06/23 PHP
Yii2中关联查询简单用法示例
2016/08/10 PHP
PHP实现的登录,注册及密码修改功能分析
2016/11/25 PHP
php简单实现文件或图片强制下载的方法
2016/12/06 PHP
PHP+Ajax 检测网络是否正常实例详解
2016/12/16 PHP
FireFox与IE 下js兼容触发click事件的代码
2008/11/20 Javascript
Extjs TriggerField在弹出窗口显示不出问题的解决方法
2010/01/08 Javascript
JQuery开发的数独游戏代码
2010/10/29 Javascript
Jquery实现的tab效果可以指定默认显示第几页
2013/10/16 Javascript
JS版的date函数(和PHP的date函数一样)
2014/05/12 Javascript
JS实现的另类手风琴效果网页内容切换代码
2015/09/08 Javascript
js 上传文件预览的简单实例
2016/08/16 Javascript
修改Jquery Dialog 位置的实现方法
2016/08/26 Javascript
JS调用某段SQL语句的方法
2016/10/20 Javascript
jQuery 插件封装的方法
2016/11/16 Javascript
AngularJS中的JSONP实例解析
2016/12/01 Javascript
如何提高Dom访问速度
2017/01/05 Javascript
Vue.js实战之利用vue-router实现跳转页面
2017/04/01 Javascript
在Create React App中启用Sass和Less的方法示例
2019/01/16 Javascript
node.js监听文件变化的实现方法
2019/04/17 Javascript
[48:48]完美世界DOTA2联赛PWL S3 Magama vs GXR 第一场 12.19
2020/12/24 DOTA
Python3学习urllib的使用方法示例
2017/11/29 Python
python按时间排序目录下的文件实现方法
2018/10/17 Python
对Python中的条件判断、循环以及循环的终止方法详解
2019/02/08 Python
Python3.6实现根据电影名称(支持电视剧名称),获取下载链接的方法
2019/08/26 Python
python获取命令行参数实例方法讲解
2020/11/02 Python
pytorch下的unsqueeze和squeeze的用法说明
2021/02/06 Python
可自定义箭头样式的CSS3气泡提示框
2016/03/16 HTML / CSS
CSS3实现可爱的小黄人动画
2016/07/11 HTML / CSS
科尔士百货公司官网:Kohl’s
2016/07/11 全球购物
爱游人:Travelliker
2017/09/05 全球购物
数控机械专业个人的自我评价
2014/01/02 职场文书
六一亲子活动总结
2014/07/01 职场文书
Python操作CSV格式文件的方法大全
2021/07/15 Python
Nginx动静分离配置实现与说明
2022/04/07 Servers