pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python随机生成指定长度密码的方法
Apr 04 Python
python 性能优化方法小结
Mar 31 Python
python中defaultdict的用法详解
Jun 07 Python
Python实现带参数与不带参数的多重继承示例
Jan 30 Python
windows环境下tensorflow安装过程详解
Mar 30 Python
Python Json模块中dumps、loads、dump、load函数介绍
May 15 Python
使用Python opencv实现视频与图片的相互转换
Jul 08 Python
Python笔试面试题小结
Sep 07 Python
如何在python开发工具PyCharm中搭建QtPy环境(教程详解)
Feb 04 Python
python把一个字符串切开的实例方法
Sep 27 Python
Django2.1.7 查询数据返回json格式的实现
Dec 29 Python
只需要这一行代码就能让python计算速度提高十倍
May 24 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
php+AJAX传送中文会导致乱码的问题的解决方法
2008/09/08 PHP
php递归列出所有文件和目录的代码
2008/09/10 PHP
php 模拟GMAIL,HOTMAIL(MSN),YAHOO,163,126邮箱登录的详细介绍
2013/06/18 PHP
提高PHP性能的编码技巧以及性能优化详细解析
2013/08/24 PHP
php添加数据到xml文件的简单例子
2016/09/08 PHP
使用javascript控制cookie显示和隐藏背景图
2014/02/12 Javascript
jQuery常用操作方法及常用函数总结
2014/06/19 Javascript
jQuery实现友好的轮播图片特效
2015/01/12 Javascript
jQuery实现类似老虎机滚动抽奖效果
2015/08/06 Javascript
jQuery插件HighCharts绘制2D柱状图、折线图和饼图的组合图效果示例【附demo源码下载】
2017/03/09 Javascript
详解AngularJs路由之Ui-router-resolve(预加载)
2017/06/13 Javascript
vue获取input输入值的问题解决办法
2017/10/17 Javascript
jQuery滚动条美化插件nicescroll简单用法示例
2018/04/18 jQuery
vue实现滑动切换效果(仅在手机模式下可用)
2020/06/29 Javascript
js实现石头剪刀布游戏
2020/10/11 Javascript
JavaScript 防盗链的原理以及破解方法
2020/12/29 Javascript
2款Python内存检测工具介绍和使用方法
2014/06/01 Python
Django1.7+python 2.78+pycharm配置mysql数据库教程
2014/11/18 Python
把MySQL表结构映射为Python中的对象的教程
2015/04/07 Python
Python引用传值概念与用法实例小结
2017/10/07 Python
python中csv文件的若干读写方法小结
2018/07/04 Python
python针对不定分隔符切割提取字符串的方法
2018/10/26 Python
Python3实现腾讯云OCR识别
2018/11/27 Python
PyQt打开保存对话框的方法和使用详解
2019/02/27 Python
详解【python】str与json类型转换
2019/04/29 Python
pandas数据筛选和csv操作的实现方法
2019/07/02 Python
Django中提供的6种缓存方式详解
2019/08/05 Python
Python 线程池用法简单示例
2019/10/02 Python
python Opencv计算图像相似度过程解析
2019/12/03 Python
Python连接字符串过程详解
2020/01/06 Python
python3 deque 双向队列创建与使用方法分析
2020/03/24 Python
俄罗斯奢侈品牌衣服、鞋子和配饰的在线商店:INTERMODA
2020/07/17 全球购物
网络维护管理员的自我评价分享
2013/11/11 职场文书
求职信的七个关键技巧
2014/02/05 职场文书
综合素质评价个性发展自我评价
2015/03/06 职场文书
给老婆的检讨书(搞笑版)
2015/05/06 职场文书