pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
ssh批量登录并执行命令的python实现代码
May 25 Python
Python字符串和文件操作常用函数分析
Apr 08 Python
python中argparse模块用法实例详解
Jun 03 Python
Python函数的周期性执行实现方法
Aug 13 Python
python原类、类的创建过程与方法详解
Jul 19 Python
Python学习笔记之Django创建第一个数据库模型的方法
Aug 07 Python
Python爬取爱奇艺电影信息代码实例
Nov 26 Python
python循环嵌套的多种使用方法解析
Nov 29 Python
Python Numpy库常见用法入门教程
Jan 16 Python
Python检测端口IP字符串是否合法
Jun 05 Python
详解基于Scrapy的IP代理池搭建
Sep 29 Python
Pytorch中的数据集划分&正则化方法
May 27 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
如何做到多笔资料的同步
2006/10/09 PHP
php读取数据库信息的几种方法
2008/05/24 PHP
PHP封装的多文件上传类实例与用法详解
2017/02/07 PHP
php实现与python进行socket通信的方法示例
2017/08/30 PHP
laravel 框架结合关联查询 when()用法分析
2019/11/22 PHP
Yii2框架中一些折磨人的坑
2019/12/15 PHP
jquery 学习之二 属性相关
2010/11/23 Javascript
JQuery 绑定select标签的onchange事件,弹出选择的值,并实现跳转、传参
2011/01/06 Javascript
分享jQuery封装好的一些常用操作
2016/07/28 Javascript
jQuery简单设置文本框回车事件的方法
2016/08/01 Javascript
JavaScript登录验证码的实现
2016/10/27 Javascript
JS实现购物车特效
2017/02/02 Javascript
移动设备手势事件库Touch.js使用详解
2017/08/18 Javascript
angularjs 获取默认选中的单选按钮的value方法
2018/02/28 Javascript
js获取form表单中name属性的值
2019/02/27 Javascript
JS document对象简单用法完整示例
2020/01/14 Javascript
如何通过javaScript去除字符串两端的空白字符
2020/02/06 Javascript
[08:07]DOTA2每周TOP10 精彩击杀集锦vol.8
2014/06/25 DOTA
Python实现购物车购物小程序
2018/04/18 Python
python多线程http压力测试脚本
2019/06/25 Python
django基于restframework的CBV封装详解
2019/08/08 Python
python matplotlib 绘图 和 dpi对应关系详解
2020/03/14 Python
python爬虫容易学吗
2020/06/02 Python
Python列表推导式实现代码实例
2020/09/09 Python
pandas按照列的值排序(某一列或者多列)
2020/12/13 Python
详解CSS3 弹性布局快速入门
2019/06/06 HTML / CSS
"引用"与多态的关系
2013/02/01 面试题
英语师范专业毕业生自荐信
2013/09/21 职场文书
客房主管岗位职责
2013/12/09 职场文书
党的群众路线教育实践活动通讯稿
2014/09/10 职场文书
建筑质检员岗位职责
2015/04/08 职场文书
诚信高考倡议书
2019/06/24 职场文书
经典法律座右铭(50句)
2019/08/15 职场文书
redis实现排行榜功能
2021/05/24 Redis
MySQL数据库10秒内插入百万条数据的实现
2021/11/01 MySQL
python Tkinter模块使用方法详解
2022/04/07 Python