Pytorch 统计模型参数量的操作 param.numel()


Posted in Python onMay 13, 2021

param.numel()

返回param中元素的数量

Pytorch 统计模型参数量的操作 param.numel()

统计模型参数量

num_params = sum(param.numel() for param in net.parameters())
print(num_params)

补充:Pytorch 查看模型参数

Pytorch 查看模型参数

查看利用Pytorch搭建模型的参数,直接看程序

import torch
# 引入torch.nn并指定别名
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        # nn.Module子类的函数必须在构造函数中执行父类的构造函数
        super(Net, self).__init__()
        
        # 卷积层 '1'表示输入图片为单通道, '6'表示输出通道数,'3'表示卷积核为3*3
        self.conv1 = nn.Conv2d(1, 6, 3) 
        #线性层,输入1350个特征,输出10个特征
        self.fc1   = nn.Linear(1350, 10)  #这里的1350是如何计算的呢?这就要看后面的forward函数
    #正向传播 
    def forward(self, x): 
        print(x.size()) # 结果:[1, 1, 32, 32]
        # 卷积 -> 激活 -> 池化 
        x = self.conv1(x) #根据卷积的尺寸计算公式,计算结果是30,具体计算公式后面第二张第四节 卷积神经网络 有详细介绍。
        x = F.relu(x)
        print(x.size()) # 结果:[1, 6, 30, 30]
        x = F.max_pool2d(x, (2, 2)) #我们使用池化层,计算结果是15
        x = F.relu(x)
        print(x.size()) # 结果:[1, 6, 15, 15]
        # reshape,‘-1'表示自适应
        #这里做的就是压扁的操作 就是把后面的[1, 6, 15, 15]压扁,变为 [1, 1350]
        x = x.view(x.size()[0], -1) 
        print(x.size()) # 这里就是fc1层的的输入1350 
        x = self.fc1(x)        
        return x

net = Net()
for parameters in net.parameters():
    print(parameters)

输出为:

Parameter containing:
tensor([[[[-0.0104, -0.0555, 0.1417],
[-0.3281, -0.0367, 0.0208],
[-0.0894, -0.0511, -0.1253]]],


[[[-0.1724, 0.2141, -0.0895],
[ 0.0116, 0.1661, -0.1853],
[-0.1190, 0.1292, -0.2451]]],


[[[ 0.1827, 0.0117, 0.2880],
[ 0.2412, -0.1699, 0.0620],
[ 0.2853, -0.2794, -0.3050]]],


[[[ 0.1930, 0.2687, -0.0728],
[-0.2812, 0.0301, -0.1130],
[-0.2251, -0.3170, 0.0148]]],


[[[-0.2770, 0.2928, -0.0875],
[ 0.0489, -0.2463, -0.1605],
[ 0.1659, -0.1523, 0.1819]]],


[[[ 0.1068, 0.2441, 0.3160],
[ 0.2945, 0.0897, 0.2978],
[ 0.0419, -0.0739, -0.2609]]]])
Parameter containing:
tensor([ 0.0782, 0.2679, -0.2516, -0.2716, -0.0084, 0.1401])
Parameter containing:
tensor([[ 1.8612e-02, 6.5482e-03, 1.6488e-02, ..., -1.3283e-02,
1.8715e-02, 5.4037e-03],
[ 1.8569e-03, 1.8022e-02, -2.3465e-02, ..., 1.6527e-03,
2.0443e-02, -2.2009e-02],
[ 9.9104e-03, 6.6134e-03, -2.7171e-02, ..., -5.7119e-03,
2.4532e-02, 2.2284e-02],
...,
[ 6.9182e-03, 1.7279e-02, -1.7783e-03, ..., 1.9354e-02,
2.1105e-03, 8.6245e-03],
[ 1.6877e-02, -1.2414e-02, 2.2409e-02, ..., -2.0604e-02,
1.3253e-02, -3.6008e-03],
[-2.1598e-02, 2.5892e-02, 1.9372e-02, ..., 1.4159e-02,
7.0983e-03, -2.3713e-02]])
Parameter containing:
tensor(1.00000e-02 *
[ 1.4703, 1.0289, 2.5069, -2.2603, -1.5218, -1.7019, 1.2569,
0.4617, -2.3082, -0.6282])

for name,parameters in net.named_parameters():
    print(name,':',parameters.size())

输出:

conv1.weight : torch.Size([6, 1, 3, 3])
conv1.bias : torch.Size([6])
fc1.weight : torch.Size([10, 1350])
fc1.bias : torch.Size([10])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单介绍Python中的filter和lambda函数的使用
Apr 07 Python
Python中实现最小二乘法思路及实现代码
Jan 04 Python
Python迭代器与生成器用法实例分析
Jul 09 Python
python实现名片管理系统
Nov 29 Python
Python2和Python3的共存和切换使用
Apr 12 Python
Python计算一个点到所有点的欧式距离实现方法
Jul 04 Python
pandas中的series数据类型详解
Jul 06 Python
Python中正反斜杠(‘/’和‘\’)的意义与用法
Aug 12 Python
解决pandas展示数据输出时列名不能对齐的问题
Nov 18 Python
torch 中各种图像格式转换的实现方法
Dec 26 Python
Python: glob匹配文件的操作
Dec 11 Python
matplotlib之pyplot模块实现添加子图subplot的使用
Apr 25 Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
用python删除文件夹中的重复图片(图片去重)
May 12 #Python
Pyhton模块和包相关知识总结
You might like
PHP 高手之路(三)
2006/10/09 PHP
利用php获取服务器时间的实现代码
2013/06/07 PHP
php中preg_replace_callback函数简单用法示例
2016/07/21 PHP
php两点地理坐标距离的计算方法
2018/12/29 PHP
TP3.2.3框架文件上传操作实例详解
2020/01/23 PHP
在jquery中的ajax方法怎样通过JSONP进行远程调用
2014/04/04 Javascript
JQuery鼠标移到小图显示大图效果的方法
2015/06/10 Javascript
针对JavaScript中this指向的简单理解
2016/08/26 Javascript
ionic隐藏tabs的方法
2016/08/29 Javascript
js实现简单的网页换肤效果
2017/01/18 Javascript
js获取地址栏中传递的参数(两种方法)
2017/02/08 Javascript
js编写简单的聊天室功能
2017/08/17 Javascript
VUE中的无限循环代码解析
2017/09/22 Javascript
AngularJS动态添加数据并删除的实例
2018/02/27 Javascript
Bootstrap 模态框自定义点击和关闭事件详解
2018/08/10 Javascript
js中调用微信的扫描二维码功能的实现代码
2020/04/11 Javascript
深入Python解释器理解Python中的字节码
2015/04/01 Python
Django代码性能优化与Pycharm Profile使用详解
2018/08/26 Python
对python中的six.moves模块的下载函数urlretrieve详解
2018/12/19 Python
python将pandas datarame保存为txt文件的实例
2019/02/12 Python
Python 使用list和tuple+条件判断详解
2019/07/30 Python
python实现对列表中的元素进行倒序打印
2019/11/23 Python
python线程定时器Timer实现原理解析
2019/11/30 Python
python通过文本在一个图中画多条线的实例
2020/02/21 Python
python 深度学习中的4种激活函数
2020/09/18 Python
matplotlib之pyplot模块之标题(title()和suptitle())
2021/02/22 Python
CSS3实现图片抽屉式效果的示例代码
2019/11/06 HTML / CSS
线程问题:wait()方法是定义在哪个类里面
2015/07/07 面试题
大学生学习自我评价
2014/01/13 职场文书
工程学毕业生自荐信
2014/06/14 职场文书
网络营销计划书
2015/01/17 职场文书
大明湖导游词
2015/02/03 职场文书
幼师辞职信怎么写
2015/02/27 职场文书
大学生心理健康活动总结
2015/05/08 职场文书
迁徙的鸟观后感
2015/06/09 职场文书
基于Python编写一个监控CPU的应用系统
2022/06/25 Python