pytorch查看网络参数显存占用量等操作


Posted in Python onMay 12, 2021

1.使用torchstat

pip install torchstat 

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数。我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数。

2.使用torchsummary

pip install torchsummary
 
from torchsummary import summary
summary(model.cuda(),input_size=(3,32,32),batch_size=-1)

使用该函数直接对参数进行提示,可以发现直接有显式输入batch_size的地方,我自己的感觉好像该函数更好一些。但是!!!不知道为什么,该函数在我的机器上一直报错!!!

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Update:经过论坛咨询,报错的原因找到了,只需要把

pip install torchsummary

修改为

pip install torch-summary

补充:Pytorch查看模型参数并计算模型参数量与可训练参数量

查看模型参数(以AlexNet为例)

import torch
import torch.nn as nn
import torchvision
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000):
        super(AlexNet,self).__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256*6*6,out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
    def forward(self,x):
        x = self.feature_extraction(x)
        x = x.view(x.size(0),256*6*6)
        x = self.classifier(x)
        return x
if __name__ =='__main__':
    # model = torchvision.models.AlexNet()
    model = AlexNet()
    
    # 打印模型参数
    #for param in model.parameters():
        #print(param)
    
    #打印模型名称与shape
    for name,parameters in model.named_parameters():
        print(name,':',parameters.size())
feature_extraction.0.weight : torch.Size([96, 3, 11, 11])
feature_extraction.3.weight : torch.Size([192, 96, 5, 5])
feature_extraction.6.weight : torch.Size([384, 192, 3, 3])
feature_extraction.8.weight : torch.Size([256, 384, 3, 3])
feature_extraction.10.weight : torch.Size([256, 256, 3, 3])
classifier.1.weight : torch.Size([4096, 9216])
classifier.1.bias : torch.Size([4096])
classifier.4.weight : torch.Size([4096, 4096])
classifier.4.bias : torch.Size([4096])
classifier.6.weight : torch.Size([1000, 4096])
classifier.6.bias : torch.Size([1000])

计算参数量与可训练参数量

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

第三方工具

from torchstat import stat
import torchvision.models as models
model = models.alexnet()
stat(model, (3, 224, 224))

pytorch查看网络参数显存占用量等操作

from torchvision.models import alexnet
import torch
from thop import profile
model = alexnet()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params)

pytorch查看网络参数显存占用量等操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
可用于监控 mysql Master Slave 状态的python代码
Feb 10 Python
Python脚本在Appium库上对移动应用实现自动化测试
Apr 17 Python
探究Python多进程编程下线程之间变量的共享问题
May 05 Python
Python functools模块学习总结
May 09 Python
Python使用pygame模块编写俄罗斯方块游戏的代码实例
Dec 08 Python
浅谈Python类的__getitem__和__setitem__特殊方法
Dec 25 Python
django框架如何集成celery进行开发
May 24 Python
Python 反转字符串(reverse)的方法小结
Feb 20 Python
浅谈Python中的作用域规则和闭包
Mar 20 Python
基于多进程中APScheduler重复运行的解决方法
Jul 22 Python
Python使用扩展库pywin32实现批量文档打印实例
Apr 09 Python
Python 数据结构之十大经典排序算法一文通关
Oct 16 Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
pytorch 中autograd.grad()函数的用法说明
python3实现无权最短路径的方法
Python入门之基础语法详解
May 11 #Python
如何利用Matlab制作一款真正的拼图小游戏
You might like
MySql 按时间段查询数据方法(实例说明)
2008/11/02 PHP
PHP遍历数组的方法汇总
2015/04/30 PHP
php  单例模式详细介绍及实现源码
2016/11/05 PHP
利用Laravel事件系统如何实现登录日志的记录详解
2017/05/20 PHP
phpcmsv9.0任意文件上传漏洞解析
2020/10/20 PHP
JavaScript入门学习书籍推荐
2008/06/12 Javascript
javascript showModalDialog,open取得父窗口的方法
2010/03/10 Javascript
JavaScript控制各种浏览器全屏模式的方法、属性和事件介绍
2014/04/03 Javascript
jQuery实现tag便签去重效果的方法
2015/01/20 Javascript
javascript实现简易计算器的代码
2016/05/31 Javascript
预防网页挂马的方法总结
2016/11/03 Javascript
vue.js组件之间传递数据的方法
2017/07/10 Javascript
浅析JS中常用类型转换及运算符表达式
2017/07/23 Javascript
详解a++和++a的区别
2017/08/30 Javascript
vue 项目 iOS WKWebView 加载
2019/04/17 Javascript
vue自定义正在加载动画的例子
2019/11/14 Javascript
js实现鼠标切换图片(无定时器)
2021/01/27 Javascript
[00:17]游戏风云独家报道:DD赛后说出数字秘密 吓死你们啊!
2014/07/13 DOTA
[03:37]2016完美“圣”典 风云人物:Mikasa专访
2016/12/07 DOTA
Python实现的Kmeans++算法实例
2014/04/26 Python
使用Python微信库itchat获得好友和群组已撤回的消息
2018/06/24 Python
Python实现的读取/更改/写入xml文件操作示例
2018/08/30 Python
Python爬虫实现验证码登录代码实例
2019/05/10 Python
Django框架orM与自定义SQL语句混合事务控制操作
2019/06/27 Python
IronPython连接MySQL的方法步骤
2019/12/27 Python
解决Python3.8运行tornado项目报NotImplementedError错误
2020/09/02 Python
印尼太阳百货公司网站:Matahari
2018/02/04 全球购物
英国综合网上购物商城:The Hut
2018/07/03 全球购物
土木工程毕业生自荐信
2013/09/21 职场文书
应届生骨科医生求职信
2013/10/31 职场文书
怎样客观的做好自我评价
2013/12/28 职场文书
人事经理岗位职责
2014/04/28 职场文书
公司2014年度工作总结
2014/12/10 职场文书
5.12护士节活动总结
2015/02/10 职场文书
八年级英语教学反思
2016/02/15 职场文书
进行数据处理的6个 Python 代码块分享
2022/04/06 Python