基于MSELoss()与CrossEntropyLoss()的区别详解


Posted in Python onJanuary 02, 2020

基于pytorch来讲

MSELoss()多用于回归问题,也可以用于one_hotted编码形式,

CrossEntropyLoss()名字为交叉熵损失函数,不用于one_hotted编码形式

MSELoss()要求batch_x与batch_y的tensor都是FloatTensor类型

CrossEntropyLoss()要求batch_x为Float,batch_y为LongTensor类型

(1)CrossEntropyLoss() 举例说明:

比如二分类问题,最后一层输出的为2个值,比如下面的代码:

class CNN (nn.Module ) :
  def __init__ ( self , hidden_size1 , output_size , dropout_p) :
    super ( CNN , self ).__init__ ( )
    self.hidden_size1 = hidden_size1
    self.output_size = output_size
    self.dropout_p = dropout_p
    
    self.conv1 = nn.Conv1d ( 1,8,3,padding =1) 
    self.fc1 = nn.Linear (8*500, self.hidden_size1 )
    self.out = nn.Linear (self.hidden_size1,self.output_size ) 
 
  
  def forward ( self , encoder_outputs ) :
    cnn_out = F.max_pool1d ( F.relu (self.conv1(encoder_outputs)),2) 
    cnn_out = F.dropout ( cnn_out ,self.dropout_p) #加一个dropout
    cnn_out = cnn_out.view (-1,8*500) 
    output_1 = torch.tanh ( self.fc1 ( cnn_out ) )
    output = self.out ( ouput_1)
    return output

最后的输出结果为:

基于MSELoss()与CrossEntropyLoss()的区别详解

上面一个tensor为output结果,下面为target,没有使用one_hotted编码。

训练过程如下:

cnn_optimizer = torch.optim.SGD(cnn.parameters(),learning_rate,momentum=0.9,\
              weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
 
def train ( input_variable , target_variable , cnn , cnn_optimizer , criterion ) :
  cnn_output = cnn( input_variable )
  print(cnn_output)
  print(target_variable)
  loss = criterion ( cnn_output , target_variable)
  cnn_optimizer.zero_grad ()
  loss.backward( )
  cnn_optimizer.step( )
  #print('loss: ',loss.item())
  return loss.item() #返回损失

说明CrossEntropyLoss()是output两位为one_hotted编码形式,但target不是one_hotted编码形式。

(2)MSELoss() 举例说明:

网络结构不变,但是标签是one_hotted编码形式。下面的图仅做说明,网络结构不太对,出来的预测也不太对。

基于MSELoss()与CrossEntropyLoss()的区别详解

如果target不是one_hotted编码形式会报错,报的错误如下。

基于MSELoss()与CrossEntropyLoss()的区别详解

目前自己理解的两者的区别,就是这样的,至于多分类问题是不是也是样的有待考察。

以上这篇基于MSELoss()与CrossEntropyLoss()的区别详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 时间操作例子和时间格式化参数小结
Apr 24 Python
Python调用C语言开发的共享库方法实例
Mar 18 Python
Python编程中字符串和列表的基本知识讲解
Oct 14 Python
python监控文件或目录变化
Jun 07 Python
Python实现爬虫从网络上下载文档的实例代码
Jun 13 Python
对Python 3.2 迭代器的next函数实例讲解
Oct 18 Python
Python HTML解析模块HTMLParser用法分析【爬虫工具】
Apr 05 Python
python多线程与多进程及其区别详解
Aug 08 Python
Pandas+Matplotlib 箱式图异常值分析示例
Dec 09 Python
python3 使用openpyxl将mysql数据写入xlsx的操作
May 15 Python
python 实现一个图形界面的汇率计算器
Nov 09 Python
Python3.8官网文档之类的基础语法阅读
Sep 04 Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
Python面向对象原理与基础语法详解
Jan 02 #Python
Pytorch 的损失函数Loss function使用详解
Jan 02 #Python
Python面向对象封装操作案例详解 II
Jan 02 #Python
You might like
业余方法DIY电子管FM收音机
2021/03/02 无线电
PHP中使用crypt()实现用户身份验证的代码
2012/09/05 PHP
PHP中的命名空间相关概念浅析
2015/01/22 PHP
PHP中array_keys和array_unique函数源码的分析
2016/02/26 PHP
php写一个函数,实现扫描并打印出自定目录下(含子目录)所有jpg文件名
2017/05/26 PHP
php数组函数array_push()、array_pop()及array_shift()简单用法示例
2020/01/26 PHP
一个对于js this关键字的问题
2007/01/09 Javascript
JavaScript中的prototype使用说明
2010/04/13 Javascript
js控制的回到页面顶端goTop的代码实现
2013/03/20 Javascript
JS循环遍历JSON数据的方法
2014/07/08 Javascript
原生Ajax 和jQuery Ajax的区别示例分析
2014/12/17 Javascript
JavaScript中property和attribute的区别详细介绍
2015/03/03 Javascript
JS仿淘宝实现的简单滑动门效果代码
2015/10/14 Javascript
plupload+artdialog实现多平台上传文件
2016/07/19 Javascript
javascript循环链表之约瑟夫环的实现方法
2017/01/16 Javascript
bootstrap Validator 模态框、jsp、表单验证 Ajax提交功能
2017/02/17 Javascript
基于Vue生产环境部署详解
2017/09/15 Javascript
详解使用vue-admin-template的优化历程
2018/05/20 Javascript
bootstrap下拉分页样式 带跳转页码
2018/12/29 Javascript
微信小程序前端自定义分享的实现方法
2019/06/13 Javascript
Vue 技巧之控制父类的 slot
2020/02/24 Javascript
jQuery 选择方法及$(this)用法实例分析
2020/05/19 jQuery
解决VUE项目localhost端口服务器拒绝连接,只能用127.0.0.1的问题
2020/08/14 Javascript
Python新手入门最容易犯的错误总结
2017/04/24 Python
Python实现进程同步和通信的方法
2018/01/02 Python
python使用xlrd模块读取xlsx文件中的ip方法
2019/01/11 Python
python实现网页自动签到功能
2019/01/21 Python
python批量修改交换机密码的示例
2020/09/22 Python
详解python日志输出使用配置文件格式
2021/02/10 Python
介绍一下MD5加密算法
2016/11/12 面试题
大客户经理岗位职责
2015/04/09 职场文书
通知函格式范文
2015/04/27 职场文书
2015年税务稽查工作总结
2015/05/26 职场文书
公司食堂管理制度
2015/08/05 职场文书
微信小程序scroll-view不能左右滑动问题的解决方法
2021/07/09 Javascript
TypeScript实用技巧 Nominal Typing名义类型详解
2022/09/23 Javascript