pytorch查看网络参数显存占用量等操作


Posted in Python onMay 12, 2021

1.使用torchstat

pip install torchstat 

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数。我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数。

2.使用torchsummary

pip install torchsummary
 
from torchsummary import summary
summary(model.cuda(),input_size=(3,32,32),batch_size=-1)

使用该函数直接对参数进行提示,可以发现直接有显式输入batch_size的地方,我自己的感觉好像该函数更好一些。但是!!!不知道为什么,该函数在我的机器上一直报错!!!

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Update:经过论坛咨询,报错的原因找到了,只需要把

pip install torchsummary

修改为

pip install torch-summary

补充:Pytorch查看模型参数并计算模型参数量与可训练参数量

查看模型参数(以AlexNet为例)

import torch
import torch.nn as nn
import torchvision
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000):
        super(AlexNet,self).__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256*6*6,out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
    def forward(self,x):
        x = self.feature_extraction(x)
        x = x.view(x.size(0),256*6*6)
        x = self.classifier(x)
        return x
if __name__ =='__main__':
    # model = torchvision.models.AlexNet()
    model = AlexNet()
    
    # 打印模型参数
    #for param in model.parameters():
        #print(param)
    
    #打印模型名称与shape
    for name,parameters in model.named_parameters():
        print(name,':',parameters.size())
feature_extraction.0.weight : torch.Size([96, 3, 11, 11])
feature_extraction.3.weight : torch.Size([192, 96, 5, 5])
feature_extraction.6.weight : torch.Size([384, 192, 3, 3])
feature_extraction.8.weight : torch.Size([256, 384, 3, 3])
feature_extraction.10.weight : torch.Size([256, 256, 3, 3])
classifier.1.weight : torch.Size([4096, 9216])
classifier.1.bias : torch.Size([4096])
classifier.4.weight : torch.Size([4096, 4096])
classifier.4.bias : torch.Size([4096])
classifier.6.weight : torch.Size([1000, 4096])
classifier.6.bias : torch.Size([1000])

计算参数量与可训练参数量

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

第三方工具

from torchstat import stat
import torchvision.models as models
model = models.alexnet()
stat(model, (3, 224, 224))

pytorch查看网络参数显存占用量等操作

from torchvision.models import alexnet
import torch
from thop import profile
model = alexnet()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params)

pytorch查看网络参数显存占用量等操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python学习笔记:字典的使用示例详解
Jun 13 Python
简单介绍Python中的try和finally和with方法
May 05 Python
python3中str(字符串)的使用教程
Mar 23 Python
python opencv 二值化 计算白色像素点的实例
Jul 03 Python
django项目用higcharts统计最近七天文章点击量
Aug 17 Python
python Opencv计算图像相似度过程解析
Dec 03 Python
Django 限制访问频率的思路详解
Dec 24 Python
python 实现简单的FTP程序
Dec 27 Python
Python图像处理库PIL的ImageFont模块使用介绍
Feb 26 Python
Python绘制全球疫情变化地图的实例代码
Apr 20 Python
keras 回调函数Callbacks 断点ModelCheckpoint教程
Jun 18 Python
详解scrapy内置中间件的顺序
Sep 28 Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
pytorch 中autograd.grad()函数的用法说明
python3实现无权最短路径的方法
Python入门之基础语法详解
May 11 #Python
如何利用Matlab制作一款真正的拼图小游戏
You might like
PHP 远程关机实现代码
2009/11/10 PHP
Yii2 assets清除缓存的方法
2016/05/16 PHP
PHP实现的CURL非阻塞调用类
2018/07/26 PHP
一文掌握PHP Xdebug 本地与远程调试(小结)
2019/04/23 PHP
javascript 定义初始化数组函数
2009/09/07 Javascript
js如何取消事件冒泡
2013/09/23 Javascript
jQuery下的动画处理总结
2013/10/10 Javascript
JavaScript获取网页表单action属性的方法
2015/04/02 Javascript
JQuery radio(单选按钮)操作方法汇总
2015/04/15 Javascript
JavaScript中string对象
2015/06/12 Javascript
理解javascript对象继承
2016/04/17 Javascript
AngularJS中的过滤器filter用法完全解析
2016/04/22 Javascript
Bootstrap轮播插件中图片变形的终极解决方案 使用jqthumb.js
2016/07/10 Javascript
EasyUI加载完Html内容样式渲染完成后显示
2016/07/25 Javascript
js 简易版滚动条实例(适用于移动端H5开发)
2017/06/26 Javascript
vue2.0 + element UI 中 el-table 数据导出Excel的方法
2018/03/02 Javascript
vue.js select下拉框绑定和取值方法
2018/03/03 Javascript
vue.js使用watch监听路由变化的方法
2018/07/08 Javascript
在Vuex使用dispatch和commit来调用mutations的区别详解
2018/09/18 Javascript
微信小程序websocket聊天室的实现示例代码
2019/02/12 Javascript
vue2配置scss的方法步骤
2019/06/06 Javascript
jQuery实现消息弹出框效果
2019/12/10 jQuery
Python使用正则匹配实现抓图代码分享
2015/04/02 Python
python实现向ppt文件里插入新幻灯片页面的方法
2015/04/28 Python
Python 2.x如何设置命令执行的超时时间实例
2017/10/19 Python
Python基于滑动平均思想实现缺失数据填充的方法
2019/02/21 Python
python脚本当作Linux中的服务启动实现方法
2019/06/28 Python
OpenCV Python实现拼图小游戏
2020/03/23 Python
使用phonegap检测网络状态的方法
2017/03/30 HTML / CSS
Tory Burch德国官网:美国时尚生活品牌
2018/01/03 全球购物
测量实习生自我鉴定
2013/09/19 职场文书
房屋买卖协议书
2014/04/10 职场文书
开会通知
2015/04/20 职场文书
同学聚会开幕词
2019/04/02 职场文书
JavaScript数组reduce()方法的语法与实例解析
2021/07/07 Javascript
JavaScript实现九宫格拖拽效果
2022/06/28 Javascript