pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
详解Python3中yield生成器的用法
Aug 20 Python
numpy自动生成数组详解
Dec 15 Python
彻底搞懂Python字符编码
Jan 23 Python
python list元素为tuple时的排序方法
Apr 18 Python
Python3用tkinter和PIL实现看图工具
Jun 21 Python
Python3编码问题 Unicode utf-8 bytes互转方法
Oct 26 Python
Python企业编码生成系统总体系统设计概述
Jul 26 Python
python提取xml里面的链接源码详解
Oct 15 Python
python编写计算器功能
Oct 25 Python
Python代码生成视频的缩略图的实例讲解
Dec 22 Python
python实现从尾到头打印单链表操作示例
Feb 22 Python
python虚拟环境模块venv使用及示例
Mar 04 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
高分R级DC动画剧《哈莉·奎茵》第二季正式预告首发
2020/04/09 欧美动漫
PHP与MySQL开发中页面乱码的产生与解决
2008/03/27 PHP
PHP中基于ts与nts版本- vc6和vc9编译版本的区别详解
2013/04/26 PHP
解析PHP实现下载文件的两种方法
2013/07/05 PHP
让PHP显示Facebook的粉丝数量方法
2014/01/08 PHP
PHP解析html类库simple_html_dom的转码bug
2014/05/22 PHP
ThinkPHP3.1新特性之内容解析输出详解
2014/06/19 PHP
php提示Failed to write session data错误的解决方法
2014/12/17 PHP
php验证码的制作思路和实现方法
2015/11/12 PHP
PHP删除数组中特定元素的两种方法
2019/02/28 PHP
为Plesk PHP7启用Oracle OCI8扩展方法总结
2019/03/29 PHP
HTML中Select不用Disabled实现ReadOnly的效果
2008/04/07 Javascript
IE6不能修改NAME问题的解决方法
2010/09/03 Javascript
js和jquery对dom节点的操作(创建/追加)
2013/04/21 Javascript
Enter转换为Tab的小例子(兼容IE,Firefox)
2013/11/14 Javascript
javascript伸缩菜单栏实现代码分享
2015/11/12 Javascript
任意Json转成无序列表的方法示例
2016/12/09 Javascript
Js apply方法详解
2017/02/16 Javascript
JS+canvas绘制的动态机械表动画效果
2017/09/12 Javascript
JS获取当前地理位置的方法
2017/10/25 Javascript
微信小程序swiper组件用法实例分析【附源码下载】
2017/12/07 Javascript
微信小程序使用checkbox显示多项选择框功能【附源码下载】
2017/12/11 Javascript
微信小程序自定义组件实现tabs选项卡功能
2018/07/14 Javascript
NodeJs 实现简单WebSocket即时通讯的示例代码
2019/08/05 NodeJs
在vue中封装方法以及多处引用该方法详解
2020/08/14 Javascript
浅谈vue websocket nodeJS 进行实时通信踩到的坑
2020/09/22 NodeJs
[03:00]《DAC最前线》之欧美新秀VS老将
2015/02/01 DOTA
对python中的argv和argc使用详解
2018/12/15 Python
解析PyCharm Python运行权限问题
2020/01/08 Python
Python 模拟生成动态产生验证码图片的方法
2020/02/01 Python
tensorflow使用指定gpu的方法
2020/02/04 Python
马德里著名的运动鞋商店:NOIRFONCE
2019/04/12 全球购物
用Python写一个简易版弹球游戏
2021/04/13 Python
MySQL下使用Inplace和Online方式创建索引的教程
2021/05/26 MySQL
十大公认最好看的动漫:《咒术回战》在榜,《钢之炼金术师》第一
2022/03/18 日漫
windows server2008 开启端口的实现方法
2022/06/25 Servers