pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Django通用视图中的函数包装
Jul 21 Python
python装饰器初探(推荐)
Jul 21 Python
详解如何用OpenCV + Python 实现人脸识别
Oct 20 Python
Python进阶之自定义对象实现切片功能
Jan 07 Python
python图像处理入门(一)
Apr 04 Python
python初学者,用python实现基本的学生管理系统(python3)代码实例
Apr 10 Python
Django模型序列化返回自然主键值示例代码
Jun 12 Python
opencv之为图像添加边界的方法示例
Dec 26 Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 Python
Python类的动态绑定实现原理
Mar 21 Python
Python pandas对excel的操作实现示例
Jul 21 Python
解决jupyter notebook启动后没有token的坑
Apr 24 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
php打造属于自己的MVC框架
2012/03/07 PHP
destoon安装出现Internal Server Error的解决方法
2014/06/21 PHP
PHP实现阳历到农历转换的类实例
2015/03/07 PHP
PHP中的流(streams)浅析
2015/07/02 PHP
laravel框架与其他框架的详细对比
2019/10/23 PHP
简单实用的反馈表单无刷新提交带验证
2013/11/15 Javascript
js 左右悬浮对联广告特效代码
2014/12/12 Javascript
js实现跨域的多种方法
2015/12/25 Javascript
全面了解js中的script标签
2016/07/04 Javascript
完美的js div拖拽实例代码
2016/09/24 Javascript
js 动态生成json对象、时时更新json对象的方法
2016/12/02 Javascript
flexslider.js实现移动端轮播
2017/02/05 Javascript
AngularJS的Filter的示例详解
2017/03/07 Javascript
node.js中debug模块的简单介绍与使用
2017/04/25 Javascript
node.js中cluster的使用教程
2017/06/09 Javascript
详谈js模块化规范
2017/07/07 Javascript
js HTML5 canvas绘制图片的方法
2017/09/08 Javascript
webpack-url-loader 解决项目中图片打包路径问题
2019/02/15 Javascript
javascript实现点亮灯泡特效示例
2019/10/15 Javascript
vue element-ui实现动态面包屑导航
2019/12/23 Javascript
基于Python3.6+splinter实现自动抢火车票
2018/09/25 Python
解决python3捕获cx_oracle抛出的异常错误问题
2018/10/18 Python
Python中调用其他程序的方式详解
2019/08/06 Python
python爬虫模块URL管理器模块用法解析
2020/02/03 Python
new_zeros() pytorch版本的转换方式
2020/02/18 Python
详解Python中的路径问题
2020/09/02 Python
文明学生事迹材料
2014/01/29 职场文书
投资协议书范本
2014/04/21 职场文书
幼儿评语大全
2014/04/30 职场文书
村主任群众路线教育实践活动个人对照检查材料思想汇报
2014/10/01 职场文书
违反工作规定检讨书范文
2014/12/14 职场文书
医院科室评语
2015/01/04 职场文书
承兑汇票延期证明
2015/06/23 职场文书
2015国庆节放假通知范文
2015/07/30 职场文书
2019职场实习报告该怎么写?
2019/07/01 职场文书
15个值得收藏的JavaScript函数
2021/09/15 Javascript