pytorch查看网络参数显存占用量等操作


Posted in Python onMay 12, 2021

1.使用torchstat

pip install torchstat 

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数。我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数。

2.使用torchsummary

pip install torchsummary
 
from torchsummary import summary
summary(model.cuda(),input_size=(3,32,32),batch_size=-1)

使用该函数直接对参数进行提示,可以发现直接有显式输入batch_size的地方,我自己的感觉好像该函数更好一些。但是!!!不知道为什么,该函数在我的机器上一直报错!!!

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Update:经过论坛咨询,报错的原因找到了,只需要把

pip install torchsummary

修改为

pip install torch-summary

补充:Pytorch查看模型参数并计算模型参数量与可训练参数量

查看模型参数(以AlexNet为例)

import torch
import torch.nn as nn
import torchvision
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000):
        super(AlexNet,self).__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256*6*6,out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
    def forward(self,x):
        x = self.feature_extraction(x)
        x = x.view(x.size(0),256*6*6)
        x = self.classifier(x)
        return x
if __name__ =='__main__':
    # model = torchvision.models.AlexNet()
    model = AlexNet()
    
    # 打印模型参数
    #for param in model.parameters():
        #print(param)
    
    #打印模型名称与shape
    for name,parameters in model.named_parameters():
        print(name,':',parameters.size())
feature_extraction.0.weight : torch.Size([96, 3, 11, 11])
feature_extraction.3.weight : torch.Size([192, 96, 5, 5])
feature_extraction.6.weight : torch.Size([384, 192, 3, 3])
feature_extraction.8.weight : torch.Size([256, 384, 3, 3])
feature_extraction.10.weight : torch.Size([256, 256, 3, 3])
classifier.1.weight : torch.Size([4096, 9216])
classifier.1.bias : torch.Size([4096])
classifier.4.weight : torch.Size([4096, 4096])
classifier.4.bias : torch.Size([4096])
classifier.6.weight : torch.Size([1000, 4096])
classifier.6.bias : torch.Size([1000])

计算参数量与可训练参数量

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

第三方工具

from torchstat import stat
import torchvision.models as models
model = models.alexnet()
stat(model, (3, 224, 224))

pytorch查看网络参数显存占用量等操作

from torchvision.models import alexnet
import torch
from thop import profile
model = alexnet()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params)

pytorch查看网络参数显存占用量等操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python pickle类库介绍(对象序列化和反序列化)
Nov 21 Python
在Python3中初学者应会的一些基本的提升效率的小技巧
Mar 31 Python
Python  pip安装lxml出错的问题解决办法
Feb 10 Python
Python编程实现的简单Web服务器示例
Jun 22 Python
Python基于time模块求程序运行时间的方法
Sep 18 Python
python类的方法属性与方法属性的动态绑定代码详解
Dec 27 Python
pytorch 转换矩阵的维数位置方法
Dec 08 Python
python 实现在tkinter中动态显示label图片的方法
Jun 13 Python
pygame实现飞机大战
Mar 11 Python
在jupyter notebook中调用.ipynb文件方式
Apr 14 Python
Python编写nmap扫描工具
Jul 21 Python
OpenCV图像变换之傅里叶变换的一些应用
Jul 26 Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
pytorch 中autograd.grad()函数的用法说明
python3实现无权最短路径的方法
Python入门之基础语法详解
May 11 #Python
如何利用Matlab制作一款真正的拼图小游戏
You might like
PHP 处理图片的类实现代码
2009/10/23 PHP
如何使用PHP获取指定日期所在月的开始日期与结束日期
2013/08/01 PHP
php导入模块文件分享
2015/03/17 PHP
PHP SPL标准库之SplFixedArray使用实例
2015/05/12 PHP
Zend Framework教程之请求对象的封装Zend_Controller_Request实例详解
2016/03/07 PHP
Joomla简单判断用户是否登录的方法
2016/05/04 PHP
json 定义
2008/06/10 Javascript
IE和firefox浏览器的event事件兼容性汇总
2009/12/06 Javascript
Javascript 检测键盘按键信息及键码值对应介绍
2013/01/03 Javascript
jquery延迟加载外部js实现代码
2013/01/11 Javascript
jquery重复提交请求的原因浅析
2014/05/23 Javascript
jQuery中大家不太了解的几个方法
2015/03/04 Javascript
IE浏览器下PNG相关功能
2015/07/05 Javascript
JS实现带提示的星级评分效果完整实例
2015/10/30 Javascript
JS实现六边形3D拖拽翻转效果的方法
2016/09/11 Javascript
分享bootstrap学习笔记心得(组件及其属性)
2017/01/11 Javascript
利用nodejs监控文件变化并使用sftp上传到服务器
2017/02/18 NodeJs
js 获取图像缩放后的实际宽高,位置等信息
2017/03/07 Javascript
vue2.0全局组件之pdf详解
2017/06/26 Javascript
微信小程序设置全局请求URL及封装wx.request请求操作示例
2019/04/02 Javascript
js实现简单进度条效果
2020/03/25 Javascript
JS错误处理与调试操作实例分析
2020/04/13 Javascript
使用django-suit为django 1.7 admin后台添加模板
2014/11/18 Python
python中nan与inf转为特定数字方法示例
2017/05/11 Python
python urllib urlopen()对象方法/代理的补充说明
2017/06/29 Python
使用EduBlock轻松学习Python编程
2018/10/08 Python
用Python画小女孩放风筝的示例
2019/11/23 Python
Python正则表达式学习小例子
2020/03/03 Python
PyCharm 2020.2下配置Anaconda环境的方法步骤
2020/09/23 Python
HTML5 解决苹果手机不能自动播放音乐问题
2017/12/27 HTML / CSS
学习十八大精神心得体会
2013/12/31 职场文书
普罗米修斯教学反思
2014/02/06 职场文书
酒店端午节促销方案
2014/02/18 职场文书
《找不到快乐的波斯猫》教学反思
2014/02/24 职场文书
汉语言文学专业自荐信
2014/06/11 职场文书
农村婚礼司仪主持词
2015/06/29 职场文书