pytorch 计算Parameter和FLOP的操作


Posted in Python onMarch 04, 2021

深度学习中,模型训练完后,查看模型的参数量和浮点计算量,在此记录下:

1 THOP

在pytorch中有现成的包thop用于计算参数数量和FLOP,首先安装thop:

pip install thop

注意安装thop时可能出现如下错误:

pytorch 计算Parameter和FLOP的操作

解决方法:

pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git # 下载源码安装

使用方法如下:

from torchvision.models import resnet50 # 引入ResNet50模型
from thop import profile
model = resnet50()
flops, params = profile(model, input_size=(1, 3, 224,224)) # profile(模型,输入数据)

对于自己构建的函数也一样,例如shuffleNetV2

from thop import profile
  from utils.ShuffleNetV2 import shufflenetv2 # 导入shufflenet2 模块
  import torch 
  
  model_shuffle = shufflenetv2(width_mult=0.5)
  model = torch.nn.DataParallel(model_shuffle)  # 调用shufflenet2 模型,该模型为自己定义的
  flop, para = profile(model, input_size=(1, 3, 224, 224),) 
  print("%.2fM" % (flop/1e6), "%.2fM" % (para/1e6))

更多细节,可参考thop GitHub链接: https://github.com/Lyken17/pytorch-OpCounter

2 计算参数

pytorch本身带有计算参数的方法

from thop import profile
  from utils.ShuffleNetV2 import shufflenetv2 # 导入shufflenet2 模块
  import torch 
  
  model_shuffle = shufflenetv2(width_mult=0.5)
  model = torch.nn.DataParallel(model_shuffle)
  total = sum([param.nelement() for param in model.parameters()])
  print("Number of parameter: %.2fM" % (total / 1e6))

补充:pytorch: 计算网络模型的计算量(FLOPs)和参数量(Params)

计算量:

FLOPs,FLOP时指浮点运算次数,s是指秒,即每秒浮点运算次数的意思,考量一个网络模型的计算量的标准。

参数量:

Params,是指网络模型中需要训练的参数总数。

第一步:安装模块(thop)

pip install thop

第二步:计算

import torch
from thop import profile
net = Model() # 定义好的网络模型
input = torch.randn(1, 3, 112, 112)
flops, params = profile(net, (inputs,))
print('flops: ', flops, 'params: ', params)

注意:

输入input的第一维度是批量(batch size),批量的大小不回影响参数量, 计算量是batch_size=1的倍数

profile(net, (inputs,))的 (inputs,)中必须加上逗号,否者会报错

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python3实现TCP协议的简单服务器和客户端案例(分享)
Jun 14 Python
JSONLINT:python的json数据验证库实例解析
Nov 28 Python
Python实现将MySQL数据库表中的数据导出生成csv格式文件的方法
Jan 11 Python
python 利用栈和队列模拟递归的过程
May 29 Python
在Mac上删除自己安装的Python方法
Oct 29 Python
浅谈python str.format与制表符\t关于中文对齐的细节问题
Jan 14 Python
python实现网页自动签到功能
Jan 21 Python
python实现一个函数版的名片管理系统过程解析
Aug 27 Python
python使用beautifulsoup4爬取酷狗音乐代码实例
Dec 04 Python
Python原始套接字编程实例解析
Jan 29 Python
使用python实现下载我们想听的歌曲,速度超快
Jul 09 Python
Python 如何创建一个线程池
Jul 28 Python
从Pytorch模型pth文件中读取参数成numpy矩阵的操作
Mar 04 #Python
python 如何用urllib与服务端交互(发送和接收数据)
Mar 04 #Python
python 求两个向量的顺时针夹角操作
Mar 04 #Python
python 制作磁力搜索工具
Mar 04 #Python
python抢购软件/插件/脚本附完整源码
Mar 04 #Python
Python 求向量的余弦值操作
Mar 04 #Python
django使用多个数据库的方法实例
Mar 04 #Python
You might like
PHP脚本的10个技巧(2)
2006/10/09 PHP
使用PHP curl模拟浏览器抓取网站信息
2013/10/28 PHP
php实现session自定义会话处理器的方法
2015/01/27 PHP
php算法实例分享
2015/07/14 PHP
php通过淘宝API查询IP地址归属等信息
2015/12/25 PHP
Android App中DrawerLayout抽屉效果的菜单编写实例
2016/03/21 PHP
PHP实现一个多功能购物网站的案例
2017/09/13 PHP
仿迅雷焦点广告效果(JQuery版)
2008/11/19 Javascript
用JavaScript页面不刷新时全选择,全删除(GridView)
2009/04/14 Javascript
angularjs的一些优化小技巧
2014/12/06 Javascript
JQuery中基础过滤选择器用法实例分析
2015/05/18 Javascript
使用node+vue.js实现SPA应用
2016/01/28 Javascript
使用JS轻松实现ionic调用键盘搜索功能(超实用)
2016/09/06 Javascript
原生JS版和jquery版实现checkbox的全选/全不选/点选/行内点选(Mr.Think)
2016/10/29 Javascript
jQuery Pagination分页插件使用方法详解
2017/02/28 Javascript
详解webpack+es6+angular1.x项目构建
2017/05/02 Javascript
vuex 的简单使用
2018/03/22 Javascript
JS弹窗 JS弹出DIV并使整个页面背景变暗功能的实现代码
2018/04/21 Javascript
vue 开发企业微信整合案例分析
2019/12/02 Javascript
JavaScript Window浏览器对象模型原理解析
2020/05/30 Javascript
[56:20]LGD vs VP Supermajor 败者组决赛 BO3 第三场 6.10
2018/07/04 DOTA
python的构建工具setup.py的方法使用示例
2017/10/23 Python
Python基于Flask框架配置依赖包信息的项目迁移部署
2018/03/02 Python
django 2.2和mysql使用的常见问题
2019/07/18 Python
Django 后台获取文件列表 InMemoryUploadedFile的例子
2019/08/07 Python
更新pip3与pyttsx3文字语音转换的实现方法
2019/08/08 Python
python 变量初始化空列表的例子
2019/11/28 Python
英国第一家领先的在线处方眼镜零售商:Glasses Direct
2018/02/23 全球购物
国际象棋商店:The Chess Store
2018/07/09 全球购物
新加坡最佳婴儿用品店:Mamahood.com.sg
2018/08/26 全球购物
Java中compareTo和compare的区别
2016/04/12 面试题
网络技术支持面试题
2013/04/22 面试题
小学毕业典礼主持词
2014/03/27 职场文书
2014年招生工作总结
2014/11/26 职场文书
2014年预算员工作总结
2014/12/05 职场文书
MySQL表字段时间设置默认值
2021/05/13 MySQL