pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
使用Python的Treq on Twisted来进行HTTP压力测试
Apr 16 Python
python3 与python2 异常处理的区别与联系
Jun 19 Python
Python冲顶大会 快来答题!
Jan 17 Python
对Python 两大环境管理神器 pyenv 和 virtualenv详解
Dec 31 Python
详解从Django Rest Framework响应中删除空字段
Jan 11 Python
详解用pyecharts Geo实现动态数据热力图城市找不到问题解决
Jun 26 Python
python开头的coding设置方法
Aug 08 Python
Python Subprocess模块原理及实例
Aug 26 Python
Django项目后台不挂断运行的方法
Aug 31 Python
详解基于Jupyter notebooks采用sklearn库实现多元回归方程编程
Mar 25 Python
浅谈django 重载str 方法
May 19 Python
python3实现Dijkstra算法最短路径的实现
May 12 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
php self,$this,const,static,->的使用
2009/10/22 PHP
PHP抽象类 介绍
2012/06/13 PHP
如何在smarty中增加类似foreach的功能自动加载数据
2013/06/26 PHP
Zend Framework教程之Autoloading用法详解
2016/03/08 PHP
php递归函数怎么用才有效
2018/02/24 PHP
ASP Json Parser修正版
2009/12/06 Javascript
Juqery Html(),append()等方法的Bug解决方法
2010/12/13 Javascript
让人期待的2011年度最佳 jQuery 插件分享
2012/03/16 Javascript
javascript使用中为什么10..toString()正常而10.toString()出错呢
2013/01/11 Javascript
js原生appendChild的bug解决心得分享
2013/07/01 Javascript
jQuery实现ichat在线客服插件
2014/12/29 Javascript
jqueryUI里拖拽排序示例分析
2015/02/26 Javascript
javascript从定义到执行 你不知道的那些事
2016/01/04 Javascript
JS简单生成两个数字之间随机数的方法
2016/08/03 Javascript
详解Angular的数据显示优化处理
2016/12/26 Javascript
百度地图JavascriptApi Marker平滑移动及车头指向行径方向
2017/03/13 Javascript
vue双向绑定简要分析
2017/03/23 Javascript
详解基于node的前端项目编译时内存溢出问题
2017/08/01 Javascript
基于angular-utils-ui-breadcrumbs使用心得(分享)
2017/11/03 Javascript
浅谈FastClick 填坑及源码解析
2018/03/02 Javascript
node.js到底要不要加分号浅析
2018/07/11 Javascript
vue拖拽排序插件vuedraggable使用方法详解
2020/08/21 Javascript
vue使用Sass时报错问题的解决方法
2020/10/14 Javascript
解决Python中由于logging模块误用导致的内存泄露
2015/04/23 Python
Python中的ctime()方法使用教程
2015/05/22 Python
Python 模拟购物车的实例讲解
2017/09/11 Python
Python基于动态规划算法解决01背包问题实例
2017/12/06 Python
django ajax json的实例代码
2018/05/29 Python
Python3使用Matplotlib 绘制精美的数学函数图形
2019/04/11 Python
毕业生求职推荐信
2013/11/04 职场文书
优秀学生干部推荐材料
2014/02/03 职场文书
2014年小学国庆节活动方案
2014/09/16 职场文书
自主招生推荐信怎么写
2015/03/26 职场文书
行政处罚事先告知书
2015/07/01 职场文书
推广普通话宣传标语口号
2015/12/26 职场文书
Python 快速验证代理IP是否有效的方法实现
2021/07/15 Python