pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python的Tornado框架异步编程入门实例
Apr 24 Python
python使用MySQLdb访问mysql数据库的方法
Aug 03 Python
python学习笔记之调用eval函数出现invalid syntax错误问题
Oct 18 Python
Python selenium 三种等待方式解读
Sep 15 Python
python pandas dataframe 行列选择,切片操作方法
Apr 10 Python
基于python的socket实现单机五子棋到双人对战
Mar 24 Python
python基于json文件实现的gearman任务自动重启代码实例
Aug 13 Python
pytorch 修改预训练model实例
Jan 18 Python
pycharm sciview的图片另存为操作
Jun 01 Python
浅谈Django前端后端值传递问题
Jul 15 Python
Python环境使用OpenCV检测人脸实现教程
Oct 19 Python
教你使用TensorFlow2识别验证码
Jun 11 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
thinkphp实现数组分页示例
2014/04/13 PHP
PHP中16个高危函数整理
2019/09/19 PHP
PHP Pipeline 实现中间件的示例代码
2020/04/26 PHP
JavaScript OOP类与继承
2009/11/15 Javascript
javascript Onunload与Onbeforeunload使用小结
2009/12/31 Javascript
JQuery的$和其它JS发生冲突的快速解决方法
2014/01/24 Javascript
javascript框架设计之浏览器的嗅探和特征侦测
2015/06/23 Javascript
JavaScript对数组进行随机重排的方法
2015/07/22 Javascript
jQuery自定义滚动条完整实例
2016/01/08 Javascript
js 博客内容进度插件详解
2017/02/19 Javascript
jQuery Ajax向服务端传递数组参数值的实例代码
2017/09/03 jQuery
微信小程序官方动态自定义底部tabBar的例子
2019/09/04 Javascript
深入学习Vue nextTick的用法及原理
2019/10/08 Javascript
Vue axios 跨域请求无法带上cookie的解决
2020/09/08 Javascript
[27:39]Ti4 循环赛第二日 LGD vs Fnatic
2014/07/11 DOTA
[48:21]林俊杰圣堂刺客超神杀戮秀
2014/10/29 DOTA
[00:37]DOTA2上海特级锦标赛 Secert 战队宣传片
2016/03/03 DOTA
[01:52]DOTA2完美大师赛Vega战队趣味视频——kpii老师小课堂
2017/11/25 DOTA
[01:01:41]DOTA2-DPC中国联赛 正赛 PSG.LGD vs Magma BO3 第二场 1月31日
2021/03/11 DOTA
Python命名空间详解
2014/08/18 Python
Python实现基于C/S架构的聊天室功能详解
2018/07/07 Python
Python文件读写常见用法总结
2019/02/22 Python
Python中flatten( ),matrix.A用法说明
2020/07/05 Python
解决pytorch 交叉熵损失输出为负数的问题
2020/07/07 Python
HTML5使用DOM进行自定义控制示例代码
2013/06/08 HTML / CSS
德国内衣、泳装和睡衣网上商店:Bigsize Dessous
2018/07/09 全球购物
Eton丹麦官网:精美的男式衬衫
2020/05/27 全球购物
【魔兽争霸3重制版】原版画面与淬火MOD画面对比
2021/03/26 魔兽争霸
食品厂厂长岗位职责
2014/01/30 职场文书
安全大检查反思材料
2014/01/31 职场文书
大学新闻系自荐书
2014/05/31 职场文书
工作能力自我评价2015
2015/03/05 职场文书
市场部岗位职责范本
2015/04/15 职场文书
iPhone13 Pro外观确定,升级4800万镜头,4月20日发新品
2021/04/15 数码科技
Mysql数据库中datetime、bigint、timestamp来表示时间选择,谁来存储时间效率最高
2021/08/23 MySQL
Pygame如何使用精灵和碰撞检测
2021/11/17 Python