pytorch查看网络参数显存占用量等操作


Posted in Python onMay 12, 2021

1.使用torchstat

pip install torchstat 

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数。我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数。

2.使用torchsummary

pip install torchsummary
 
from torchsummary import summary
summary(model.cuda(),input_size=(3,32,32),batch_size=-1)

使用该函数直接对参数进行提示,可以发现直接有显式输入batch_size的地方,我自己的感觉好像该函数更好一些。但是!!!不知道为什么,该函数在我的机器上一直报错!!!

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Update:经过论坛咨询,报错的原因找到了,只需要把

pip install torchsummary

修改为

pip install torch-summary

补充:Pytorch查看模型参数并计算模型参数量与可训练参数量

查看模型参数(以AlexNet为例)

import torch
import torch.nn as nn
import torchvision
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000):
        super(AlexNet,self).__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256*6*6,out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
    def forward(self,x):
        x = self.feature_extraction(x)
        x = x.view(x.size(0),256*6*6)
        x = self.classifier(x)
        return x
if __name__ =='__main__':
    # model = torchvision.models.AlexNet()
    model = AlexNet()
    
    # 打印模型参数
    #for param in model.parameters():
        #print(param)
    
    #打印模型名称与shape
    for name,parameters in model.named_parameters():
        print(name,':',parameters.size())
feature_extraction.0.weight : torch.Size([96, 3, 11, 11])
feature_extraction.3.weight : torch.Size([192, 96, 5, 5])
feature_extraction.6.weight : torch.Size([384, 192, 3, 3])
feature_extraction.8.weight : torch.Size([256, 384, 3, 3])
feature_extraction.10.weight : torch.Size([256, 256, 3, 3])
classifier.1.weight : torch.Size([4096, 9216])
classifier.1.bias : torch.Size([4096])
classifier.4.weight : torch.Size([4096, 4096])
classifier.4.bias : torch.Size([4096])
classifier.6.weight : torch.Size([1000, 4096])
classifier.6.bias : torch.Size([1000])

计算参数量与可训练参数量

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

第三方工具

from torchstat import stat
import torchvision.models as models
model = models.alexnet()
stat(model, (3, 224, 224))

pytorch查看网络参数显存占用量等操作

from torchvision.models import alexnet
import torch
from thop import profile
model = alexnet()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params)

pytorch查看网络参数显存占用量等操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python实现的HTTP并发测试完整示例
Apr 23 Python
Python修改MP3文件的方法
Jun 15 Python
在Django的视图(View)外使用Session的方法
Jul 23 Python
Python中time模块与datetime模块在使用中的不同之处
Nov 24 Python
简单谈谈python中的lambda表达式
Jan 19 Python
Python输出由1,2,3,4组成的互不相同且无重复的三位数
Feb 01 Python
python实现的发邮件功能示例
Sep 11 Python
python实现画循环圆
Nov 23 Python
python随机模块random使用方法详解
Feb 14 Python
python如何判断IP地址合法性
Apr 05 Python
Django用户登录与注册系统的实现示例
Jun 03 Python
Python数据可视化图实现过程详解
Jun 12 Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
pytorch 中autograd.grad()函数的用法说明
python3实现无权最短路径的方法
Python入门之基础语法详解
May 11 #Python
如何利用Matlab制作一款真正的拼图小游戏
You might like
windows下升级PHP到5.3.3的过程及注意事项
2010/10/12 PHP
Laravel 5框架学习之表单
2015/04/08 PHP
PHP实现下载远程图片保存到本地的方法
2017/06/19 PHP
基于PHP实现栈数据结构和括号匹配算法示例
2017/08/10 PHP
js压缩利器
2007/02/20 Javascript
JavaScript 特殊字符
2007/04/05 Javascript
清除网页历史记录,屏蔽后退按钮!
2008/12/22 Javascript
js/jquery判断浏览器类型的方法小结
2015/05/12 Javascript
js实现Select头像选择实时预览代码
2015/08/17 Javascript
详解JavaScript对象类型
2016/06/16 Javascript
基于Bootstrap的UI扩展 StyleBootstrap
2016/06/17 Javascript
详解Angular 自定义结构指令
2017/06/21 Javascript
[js高手之路]从原型链开始图解继承到组合继承的产生详解
2017/08/28 Javascript
NodeJS父进程与子进程资源共享原理与实现方法
2018/03/16 NodeJs
boostrap模态框二次弹出清空原有内容的方法
2018/08/10 Javascript
webpack4.0+vue2.0利用批处理生成前端单页或多页应用的方法
2019/06/28 Javascript
使用easyui从servlet传递json数据到前端页面的两种方法
2019/09/05 Javascript
python使用循环实现批量创建文件夹示例
2014/03/25 Python
Python中声明只包含一个元素的元组数据方法
2014/08/25 Python
一百多行python代码实现抢票助手
2018/09/25 Python
Python之两种模式的生产者消费者模型详解
2018/10/26 Python
Django如何自定义model创建数据库索引的顺序
2019/06/20 Python
django实现用户注册实例讲解
2019/10/30 Python
通过实例了解Python str()和repr()的区别
2020/01/17 Python
Python调用Windows命令打印文件
2020/02/07 Python
Clarks英国官方网站:全球领军鞋履品牌
2016/11/26 全球购物
英国的领先快速时尚零售商:In The Style
2019/03/25 全球购物
意大利奢侈品综合电商网站:MODES
2019/12/14 全球购物
幼儿教师自我鉴定
2013/11/02 职场文书
大学生活学习的自我评价
2013/12/03 职场文书
社区敬老月活动实施方案
2014/02/17 职场文书
关于运动会的口号
2014/06/07 职场文书
助人为乐好少年事迹材料
2014/08/18 职场文书
个人查摆剖析材料
2014/10/04 职场文书
vue基于Teleport实现Modal组件
2021/05/31 Vue.js
Ruby处理YAML和json数据
2022/04/18 Ruby