pytorch查看网络参数显存占用量等操作


Posted in Python onMay 12, 2021

1.使用torchstat

pip install torchstat 

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数。我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数。

2.使用torchsummary

pip install torchsummary
 
from torchsummary import summary
summary(model.cuda(),input_size=(3,32,32),batch_size=-1)

使用该函数直接对参数进行提示,可以发现直接有显式输入batch_size的地方,我自己的感觉好像该函数更好一些。但是!!!不知道为什么,该函数在我的机器上一直报错!!!

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Update:经过论坛咨询,报错的原因找到了,只需要把

pip install torchsummary

修改为

pip install torch-summary

补充:Pytorch查看模型参数并计算模型参数量与可训练参数量

查看模型参数(以AlexNet为例)

import torch
import torch.nn as nn
import torchvision
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000):
        super(AlexNet,self).__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256*6*6,out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
    def forward(self,x):
        x = self.feature_extraction(x)
        x = x.view(x.size(0),256*6*6)
        x = self.classifier(x)
        return x
if __name__ =='__main__':
    # model = torchvision.models.AlexNet()
    model = AlexNet()
    
    # 打印模型参数
    #for param in model.parameters():
        #print(param)
    
    #打印模型名称与shape
    for name,parameters in model.named_parameters():
        print(name,':',parameters.size())
feature_extraction.0.weight : torch.Size([96, 3, 11, 11])
feature_extraction.3.weight : torch.Size([192, 96, 5, 5])
feature_extraction.6.weight : torch.Size([384, 192, 3, 3])
feature_extraction.8.weight : torch.Size([256, 384, 3, 3])
feature_extraction.10.weight : torch.Size([256, 256, 3, 3])
classifier.1.weight : torch.Size([4096, 9216])
classifier.1.bias : torch.Size([4096])
classifier.4.weight : torch.Size([4096, 4096])
classifier.4.bias : torch.Size([4096])
classifier.6.weight : torch.Size([1000, 4096])
classifier.6.bias : torch.Size([1000])

计算参数量与可训练参数量

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

第三方工具

from torchstat import stat
import torchvision.models as models
model = models.alexnet()
stat(model, (3, 224, 224))

pytorch查看网络参数显存占用量等操作

from torchvision.models import alexnet
import torch
from thop import profile
model = alexnet()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params)

pytorch查看网络参数显存占用量等操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python登陆asp网站页面的实现代码
Jan 14 Python
Python实现从订阅源下载图片的方法
Mar 11 Python
python读写二进制文件的方法
May 09 Python
Python批量按比例缩小图片脚本分享
May 21 Python
Python OS模块常用函数说明
May 23 Python
Python中map和列表推导效率比较实例分析
Jun 17 Python
JSON Web Tokens的实现原理
Apr 02 Python
利用Django内置的认证视图实现用户密码重置功能详解
Nov 24 Python
Python模拟脉冲星伪信号频率实例代码
Jan 03 Python
解决Python的str强转int时遇到的问题
Apr 09 Python
python3使用matplotlib绘制条形图
Mar 25 Python
Django框架实现的普通登录案例【使用POST方法】
May 15 Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
pytorch 中autograd.grad()函数的用法说明
python3实现无权最短路径的方法
Python入门之基础语法详解
May 11 #Python
如何利用Matlab制作一款真正的拼图小游戏
You might like
PHP fastcgi模式上传大文件(大约有300多K)报错
2014/09/28 PHP
php树型类实例
2014/12/05 PHP
php遍历树的常用方法汇总
2015/06/18 PHP
PHP 返回13位时间戳的实现代码
2016/05/13 PHP
PHP实现将MySQL重复ID二维数组重组为三维数组的方法
2016/08/01 PHP
PHP全局使用Laravel辅助函数dd
2019/12/26 PHP
Javascript 更新 JavaScript 数组的 uniq 方法
2008/01/23 Javascript
Jquery 获取checkbox的checked问题
2011/11/16 Javascript
检测input每次的输入是否合法遇到汉字输入就有问题
2012/05/23 Javascript
如何使用jQUery获取选中radio对应的值(一句代码)
2013/06/03 Javascript
jQuery中unwrap()方法用法实例
2015/01/16 Javascript
手机开发必备技巧:javascript及CSS功能代码分享
2015/05/25 Javascript
Javascript的表单验证-初识正则表达式
2016/03/18 Javascript
AngularJS ng-blur 指令详解及简单实例
2016/07/30 Javascript
JavaScript中cookie工具函数封装的示例代码
2016/10/11 Javascript
javascript创建对象的3种方法
2016/11/02 Javascript
node实现分片下载的示例代码
2018/10/17 Javascript
VUE中使用MUI方法
2019/02/12 Javascript
Node.JS发送http请求批量检查文件中的网页地址、服务是否有效可用
2019/11/20 Javascript
JavaScript中如何对多维数组(矩阵)去重的实现
2019/12/04 Javascript
Python和php通信乱码问题解决方法
2014/04/15 Python
跟老齐学Python之从格式化表达式到方法
2014/09/28 Python
Python远程桌面协议RDPY安装使用介绍
2015/04/15 Python
使用Python脚本实现批量网站存活检测遇到问题及解决方法
2016/10/11 Python
python中pandas.DataFrame对行与列求和及添加新行与列示例
2017/03/12 Python
使用python和Django完成博客数据库的迁移方法
2018/01/05 Python
python中强大的format函数实例详解
2018/12/05 Python
Django实现单用户登录的方法示例
2019/03/28 Python
Python实现12306火车票抢票系统
2019/07/04 Python
django连接oracle时setting 配置方法
2019/08/29 Python
PyCharm最新激活码(2020/10/27全网最新)
2020/10/27 Python
HTML5拖拽文件上传的示例代码
2021/03/04 HTML / CSS
什么是跨站脚本攻击
2014/12/11 面试题
教师工作自我鉴定范文
2014/09/14 职场文书
小学语文的各类谚语(70首)
2019/08/15 职场文书
Android Gradle 插件自定义Plugin实现注意事项
2022/06/16 Java/Android