pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python实现ip查询示例
Mar 26 Python
非递归的输出1-N的全排列实例(推荐)
Apr 11 Python
利用numpy+matplotlib绘图的基本操作教程
May 03 Python
pandas中Timestamp类用法详解
Dec 11 Python
Python即时网络爬虫项目启动说明详解
Feb 23 Python
Python根据已知邻接矩阵绘制无向图操作示例
Jun 23 Python
在Python中表示一个对象的方法
Jun 25 Python
Python 网络编程之TCP客户端/服务端功能示例【基于socket套接字】
Oct 12 Python
利用matplotlib实现根据实时数据动态更新图形
Dec 13 Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 Python
解决python 执行sql语句时所传参数含有单引号的问题
Jun 06 Python
pandas实现导出数据的四种方式
Dec 13 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
支持中文字母数字、自定义字体php验证码代码
2012/02/27 PHP
jQuery+Ajax+PHP“喜欢”评级功能实现代码
2015/10/08 PHP
深入解析PHP的Yii框架中的event事件机制
2016/03/17 PHP
关于 Laravel Redis 多个进程同时取队列问题详解
2017/12/25 PHP
基于JQuery的cookie插件
2010/04/07 Javascript
javascript之典型高阶函数应用介绍二
2013/01/10 Javascript
Javascript 数组排序详解
2014/10/22 Javascript
教你如何终止JQUERY的$.AJAX请求
2016/02/23 Javascript
基于jQuery.validate及Bootstrap的tooltip开发气泡样式的表单校验组件思路详解
2016/07/18 Javascript
VUE 更好的 ajax 上传处理 axios.js实现代码
2017/05/10 Javascript
EasyUI Datebox 日期验证之开始日期小于结束时间
2017/05/19 Javascript
JavaScript 中的12种循环遍历方法【总结】
2018/05/31 Javascript
使用js实现将后台传入的json数据放在前台显示
2018/08/06 Javascript
微信小程序实现日历功能
2018/11/27 Javascript
Angular6使用forRoot() 注册单一实例服务问题
2019/08/27 Javascript
Ant design vue table 单击行选中 勾选checkbox教程
2020/10/24 Javascript
python中requests爬去网页内容出现乱码问题解决方法介绍
2017/10/25 Python
TensorFlow saver指定变量的存取
2018/03/10 Python
Python实现读写INI配置文件的方法示例
2018/06/09 Python
10分钟教你用python动画演示深度优先算法搜寻逃出迷宫的路径
2019/08/12 Python
如何在Anaconda中打开python自带idle
2020/09/21 Python
python如何调用php文件中的函数详解
2020/12/29 Python
纯CSS3发光分享按钮的实现教程
2014/09/06 HTML / CSS
详解css3中的伪类before和after常见用法
2020/11/17 HTML / CSS
三只松鼠官方旗舰店:全网坚果销售第1
2017/11/25 全球购物
英国排名第一的礼品体验公司:Red Letter Days
2018/08/16 全球购物
亚马逊海外购:亚马逊美国、英国、日本、德国直邮
2021/03/18 全球购物
Java里面如何把一个Array数组转换成Collection, List
2013/07/26 面试题
软件测试企业面试试卷
2016/07/13 面试题
学历公证委托书
2014/04/09 职场文书
小学师德师风演讲稿
2014/09/02 职场文书
2015新生加入学生会自荐书
2015/03/24 职场文书
2015年行风建设工作总结
2015/05/15 职场文书
酒桌上的开场白
2015/06/01 职场文书
2015年七夕情人节感言
2015/08/03 职场文书
2016年公务员六五普法心得体会
2016/01/21 职场文书