pytorch MSELoss计算平均的实现方法


Posted in Python onMay 12, 2021

给定损失函数的输入y,pred,shape均为bxc。

若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

补充:PyTorch中MSELoss的使用

参数

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。

reduction的可选参数有:'none' 、'mean' 、'sum'

reduction='none':求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。

reduction='mean':求所有对应位置差的平方的均值,返回的是一个标量。

reduction='sum':求所有对应位置差的平方的和,返回的是一个标量。

更多可查看官方文档​

举例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

则输出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

则输出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

则输出:

tensor(37.)

在反向传播时的使用

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。

这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。

但具体设置为'sum'还是'mean'都是可以的。

若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。

若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
使用Python神器对付12306变态验证码
Jan 05 Python
用Python编写简单的微博爬虫
Mar 04 Python
Linux下通过python访问MySQL、Oracle、SQL Server数据库的方法
Apr 23 Python
python读取与写入csv格式文件的示例代码
Dec 16 Python
Windows上使用Python增加或删除权限的方法
Apr 24 Python
在Mac下使用python实现简单的目录树展示方法
Nov 01 Python
Pyqt5如何让QMessageBox按钮显示中文示例代码
Apr 11 Python
Python中字符串与编码示例代码
May 20 Python
在Python中使用MySQL--PyMySQL的基本使用方法
Nov 19 Python
python 利用百度API识别图片文字(多线程版)
Dec 14 Python
Python 内存管理机制全面分析
Jan 16 Python
python多线程方法详解
Jan 18 Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
You might like
PHP Curl出现403错误的解决办法
2014/05/29 PHP
PHP的serialize序列化数据以及JSON格式化数据分析
2015/10/10 PHP
MSN消息提示类
2006/09/05 Javascript
使用jQuery清空file文件域的解决方案
2013/04/12 Javascript
用javascript添加控件自定义属性解析
2013/11/25 Javascript
js动态添加onclick事件可传参数与不传参数
2014/07/29 Javascript
js中匿名函数的创建与调用方法分析
2014/12/19 Javascript
jQuery实现跨域iframe接口方法调用
2015/03/14 Javascript
js+html5通过canvas指定开始和结束点绘制线条的方法
2015/06/05 Javascript
JavaScript实现输入框(密码框)出现提示语
2016/01/12 Javascript
简单总结JavaScript中的String字符串类型
2016/05/26 Javascript
JS中parseInt()和map()用法分析
2016/12/16 Javascript
微信小程序 两种为对象属性赋值的方式详解
2017/02/23 Javascript
Angualrjs和bootstrap相结合实现数据表格table
2017/03/30 Javascript
详解NODEJS基于FFMPEG视频推流测试
2017/11/17 NodeJs
JavaScript的setter与getter方法
2017/11/29 Javascript
jQuery表单元素过滤选择器用法实例分析
2019/02/20 jQuery
环形加载进度条封装(Vue插件版和原生js版)
2019/12/04 Javascript
Vue的data、computed、watch源码浅谈
2020/04/04 Javascript
jQuery中event.target和this的区别详解
2020/08/13 jQuery
antd中table展开行默认展示,且不需要前边的加号操作
2020/11/02 Javascript
Python for循环生成列表的实例
2018/06/15 Python
python保存文件方法小结
2018/07/27 Python
在Python 字典中一键对应多个值的实例
2019/02/03 Python
Python 异常的捕获、异常的传递与主动抛出异常操作示例
2019/09/23 Python
Python垃圾回收机制三种实现方法
2020/04/27 Python
python 实现单例模式的5种方法
2020/09/23 Python
python tkinter的消息框模块(messagebox,simpledialog)
2020/11/07 Python
【HTML5】3D模型--百行代码实现旋转立体魔方实例
2016/12/16 HTML / CSS
short s1 = 1; s1 = s1 + 1;有什么错? short s1 = 1; s1 += 1;有什么错?
2014/09/26 面试题
公司委托书范本
2014/04/04 职场文书
优秀班组申报材料
2014/12/25 职场文书
贷款承诺书
2015/01/20 职场文书
大雁塔英文导游词
2015/02/10 职场文书
2016优秀青年志愿者事迹材料
2016/02/25 职场文书
Python实现天气查询软件
2021/06/07 Python