pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python三元运算符实现方法
Dec 17 Python
Python实现压缩与解压gzip大文件的方法
Sep 18 Python
Python 递归函数详解及实例
Dec 27 Python
python subprocess 杀掉全部派生的子进程方法
Jan 16 Python
PyQt5每天必学之滑块控件QSlider
Apr 20 Python
python json.loads兼容单引号数据的方法
Dec 19 Python
Python tkinter三种布局实例详解
Jan 06 Python
Python SMTP发送电子邮件的示例
Sep 23 Python
解决Pymongo insert时会自动添加_id的问题
Dec 05 Python
python的dict判断key是否存在的方法
Dec 09 Python
ASP.NET Core中的配置详解
Feb 05 Python
python识别围棋定位棋盘位置
Jul 26 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
source.php查看源文件
2006/12/09 PHP
9个PHP开发常用功能函数小结
2011/07/15 PHP
thinkphp备份数据库的方法分享
2015/01/04 PHP
PHP实现递归无限级分类
2015/10/22 PHP
PHP新建类问题分析及解决思路
2015/11/19 PHP
优化WordPress的Google字体以加速国内服务器上的运行
2015/11/24 PHP
详解PHP防止直接访问.php 文件的实现方法
2017/07/28 PHP
PHP获取当前系统时间的方法小结
2018/10/03 PHP
php常用经典函数集锦【数组、字符串、栈、队列、排序等】
2019/08/23 PHP
PHP全局使用Laravel辅助函数dd
2019/12/26 PHP
Tips 带三角可关闭的文字提示
2010/10/06 Javascript
使用GruntJS链接与压缩多个JavaScript文件过程详解
2013/08/02 Javascript
js全选实现和判断是否有复选框选中的方法
2015/02/17 Javascript
jQuery实现图片上传和裁剪插件Croppie
2015/11/29 Javascript
javascript的理解及经典案例分析
2016/05/20 Javascript
微信小程序 canvas API详解及实例代码
2016/10/08 Javascript
angularjs封装$http为factory的方法
2017/05/18 Javascript
Javascript实现运算符重载详解
2018/04/07 Javascript
微信小程序实现分享到朋友圈功能
2018/07/19 Javascript
vue+vue-router转场动画的实例代码
2018/09/01 Javascript
Python实现SVN的目录周期性备份实例
2015/07/17 Python
python django事务transaction源码分析详解
2017/03/17 Python
Matplotlib 生成不同大小的subplots实例
2018/05/25 Python
Python subprocess库的使用详解
2018/10/26 Python
pandas DataFrame 交集并集补集的实现
2019/06/24 Python
货代行业个人求职简历的自我评价
2013/10/22 职场文书
医学专业五年以上个人求职信
2013/12/03 职场文书
四个太阳教学反思
2014/02/01 职场文书
函授大学生自我鉴定
2014/02/05 职场文书
室内设计专业毕业生求职信
2014/05/02 职场文书
卫生标语大全
2014/06/21 职场文书
学前班语言教学计划
2015/01/20 职场文书
2015年安全工作总结范文
2015/04/02 职场文书
公司年会开场白
2015/06/01 职场文书
遗愿清单观后感
2015/06/09 职场文书
全国劳模先进事迹材料(2016精选版)
2016/02/25 职场文书