pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python实现批量下载文件
May 17 Python
Python实现监控程序执行时间并将其写入日志的方法
Jun 30 Python
Python及Django框架生成二维码的方法分析
Jan 31 Python
python2 与 python3 实现共存的方法
Jul 12 Python
Python魔法方法功能与用法简介
Apr 04 Python
python制作英语翻译小工具代码实例
Sep 09 Python
Python实现中值滤波去噪方式
Dec 18 Python
Python文件名匹配与文件复制的实现
Dec 11 Python
深入理解python多线程编程
Apr 18 Python
python flask框架快速入门
May 14 Python
利用Python判断你的密码难度等级
Jun 02 Python
python通过函数名调用函数的几种方法总结
Jun 07 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
php chr() ord()中文截取乱码问题解决方法
2008/09/08 PHP
CI(CodeIgniter)框架配置
2014/06/10 PHP
Laravel中使用FormRequest进行表单验证方法及问题汇总
2016/06/19 PHP
thinkPHP商城公告功能开发问题分析
2016/12/01 PHP
Zend Framework入门应用实例详解
2016/12/11 PHP
yii2.0整合阿里云oss删除单个文件的方法
2017/09/19 PHP
kmock javascript 单元测试代码
2011/02/06 Javascript
基于jquery的3d效果实现代码
2011/03/23 Javascript
js冒泡法和数组转换成字符串示例代码
2013/08/14 Javascript
Js操作Select大全(取值、设置选中等等)
2013/10/29 Javascript
js触发select onchange事件的小技巧
2014/08/05 Javascript
基于javascript、ajax、memcache和PHP实现的简易在线聊天室
2015/02/03 Javascript
js操作数组函数实例小结
2015/12/10 Javascript
javascript跨域请求包装函数与用法示例
2016/11/03 Javascript
vue组件学习教程
2017/09/09 Javascript
vue检测对象和数组的变化分析
2018/06/30 Javascript
Jquery的Ajax技术使用方法
2019/01/21 jQuery
微信小程序实现上传多个文件 超过10个
2020/03/30 Javascript
js实现百度淘宝搜索功能
2020/02/17 Javascript
Vue 禁用浏览器的前进后退操作
2020/09/04 Javascript
jquery实现鼠标悬浮弹出气泡提示框
2020/12/23 jQuery
pycharm 使用心得(九)解决No Python interpreter selected的问题
2014/06/06 Python
python虚拟环境迁移方法
2019/01/03 Python
通过实例了解python property属性
2019/11/01 Python
keras 指定程序在某块卡上训练实例
2020/06/22 Python
python3处理word文档实例分析
2020/12/01 Python
联想中国官方商城:Lenovo China
2017/10/18 全球购物
意大利包包和行李箱销售网站:Bagaglio.it
2021/03/02 全球购物
回门宴新郎答谢词
2014/01/12 职场文书
优秀教师先进事迹
2014/01/22 职场文书
超市中秋节活动方案
2014/02/12 职场文书
2014年合同管理工作总结
2014/12/02 职场文书
结婚当天新郎保证书
2015/05/08 职场文书
硕士学位申请报告
2015/05/15 职场文书
聘任通知书
2015/09/21 职场文书
关于JavaScript 中 if包含逗号表达式
2021/11/27 Javascript