pytorch查看网络参数显存占用量等操作


Posted in Python onMay 12, 2021

1.使用torchstat

pip install torchstat 

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数。我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数。

2.使用torchsummary

pip install torchsummary
 
from torchsummary import summary
summary(model.cuda(),input_size=(3,32,32),batch_size=-1)

使用该函数直接对参数进行提示,可以发现直接有显式输入batch_size的地方,我自己的感觉好像该函数更好一些。但是!!!不知道为什么,该函数在我的机器上一直报错!!!

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Update:经过论坛咨询,报错的原因找到了,只需要把

pip install torchsummary

修改为

pip install torch-summary

补充:Pytorch查看模型参数并计算模型参数量与可训练参数量

查看模型参数(以AlexNet为例)

import torch
import torch.nn as nn
import torchvision
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000):
        super(AlexNet,self).__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256*6*6,out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
    def forward(self,x):
        x = self.feature_extraction(x)
        x = x.view(x.size(0),256*6*6)
        x = self.classifier(x)
        return x
if __name__ =='__main__':
    # model = torchvision.models.AlexNet()
    model = AlexNet()
    
    # 打印模型参数
    #for param in model.parameters():
        #print(param)
    
    #打印模型名称与shape
    for name,parameters in model.named_parameters():
        print(name,':',parameters.size())
feature_extraction.0.weight : torch.Size([96, 3, 11, 11])
feature_extraction.3.weight : torch.Size([192, 96, 5, 5])
feature_extraction.6.weight : torch.Size([384, 192, 3, 3])
feature_extraction.8.weight : torch.Size([256, 384, 3, 3])
feature_extraction.10.weight : torch.Size([256, 256, 3, 3])
classifier.1.weight : torch.Size([4096, 9216])
classifier.1.bias : torch.Size([4096])
classifier.4.weight : torch.Size([4096, 4096])
classifier.4.bias : torch.Size([4096])
classifier.6.weight : torch.Size([1000, 4096])
classifier.6.bias : torch.Size([1000])

计算参数量与可训练参数量

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

第三方工具

from torchstat import stat
import torchvision.models as models
model = models.alexnet()
stat(model, (3, 224, 224))

pytorch查看网络参数显存占用量等操作

from torchvision.models import alexnet
import torch
from thop import profile
model = alexnet()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params)

pytorch查看网络参数显存占用量等操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python发送arp欺骗攻击代码分析
Jan 16 Python
Python使用百度API上传文件到百度网盘代码分享
Nov 08 Python
Python实现去除代码前行号的方法
Mar 10 Python
C#返回当前系统所有可用驱动器符号的方法
Apr 18 Python
Python编程求解二叉树中和为某一值的路径代码示例
Jan 04 Python
详解Python 切片语法
Jun 10 Python
python网络编程之多线程同时接受和发送
Sep 03 Python
使用virtualenv创建Python环境及PyQT5环境配置的方法
Sep 10 Python
python 函数的缺省参数使用注意事项分析
Sep 17 Python
深入了解Python 变量作用域
Jul 24 Python
用python实现前向分词最大匹配算法的示例代码
Aug 06 Python
Python中的面向接口编程示例详解
Jan 17 Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
pytorch 中autograd.grad()函数的用法说明
python3实现无权最短路径的方法
Python入门之基础语法详解
May 11 #Python
如何利用Matlab制作一款真正的拼图小游戏
You might like
PHP与javascript的两种交互方式
2006/10/09 PHP
PHP var_dump遍历对象属性的函数与应用代码
2010/06/04 PHP
php中global和$GLOBALS[]的分析之一
2012/02/02 PHP
php中常用的预定义变量小结
2012/05/09 PHP
给WordPress中的留言加上楼层号的PHP代码实例
2015/12/14 PHP
PHP入门教程之会话控制技巧(cookie与session)
2016/09/11 PHP
js文字滚动停顿效果代码
2008/06/28 Javascript
用jQuery扩展自写的 UI导航
2010/01/13 Javascript
javascript 数据类型转换(parseInt,parseFloat)
2010/07/20 Javascript
仿百度输入框智能提示的js代码
2013/08/22 Javascript
jquery foreach使用示例
2013/09/12 Javascript
jquery删除指定的html标签并保留标签内文本内容的方法
2014/04/02 Javascript
jquery防止重复执行动画避免页面混乱
2014/04/22 Javascript
javascript与jquery中的this关键字用法实例分析
2015/12/24 Javascript
对象转换为原始值的实现方法
2016/06/06 Javascript
JS实现的驼峰式和连字符式转换功能分析
2016/12/21 Javascript
nodejs中sleep功能实现暂停几秒的方法
2017/07/12 NodeJs
React中使用collections时key的重要性详解
2017/08/07 Javascript
jQuery序列化form表单数据为JSON对象的实现方法
2018/09/20 jQuery
Vue 之孙组件向爷组件通信的实现
2019/04/23 Javascript
javascript实现支付宝滑块验证码效果
2020/07/24 Javascript
在vue-cli3.0 中使用预处理器 (Sass/Less/Stylus) 配置全局变量操作
2020/08/10 Javascript
Vue插槽_特殊特性slot,slot-scope与指令v-slot说明
2020/09/04 Javascript
解决idea开发遇到javascript动态添加html元素时中文乱码的问题
2020/09/29 Javascript
通过5个知识点轻松搞定Python的作用域
2016/09/09 Python
基于Django用户认证系统详解
2018/02/21 Python
简单了解python列表和元组的区别
2020/05/14 Python
python Paramiko使用示例
2020/09/21 Python
全方位了解CSS3的Regions扩展
2015/08/07 HTML / CSS
Timberland俄罗斯官方网上商店:全球领先的户外品牌
2020/03/15 全球购物
企业法人代表证明书
2014/09/27 职场文书
党组织领导班子整改方案
2014/10/25 职场文书
自我工作评价范文
2015/03/06 职场文书
工程技术员岗位职责
2015/04/11 职场文书
2015年度团总支工作总结
2015/04/23 职场文书
车辆挂靠协议书
2016/03/23 职场文书