pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 中的列表解析和生成表达式
Mar 10 Python
Python中的文件和目录操作实现代码
Mar 13 Python
Python内置函数的用法实例教程
Sep 08 Python
python中os操作文件及文件路径实例汇总
Jan 15 Python
Python创建系统目录的方法
Mar 11 Python
python中引用与复制用法实例分析
Jun 04 Python
python操作字典类型的常用方法(推荐)
May 16 Python
浅谈python字符串方法的简单使用
Jul 18 Python
python多行字符串拼接使用小括号的方法
Mar 19 Python
pytorch-RNN进行回归曲线预测方式
Jan 14 Python
Keras Convolution1D与Convolution2D区别说明
May 22 Python
python 爬虫基本使用——统计杭电oj题目正确率并排序
Oct 26 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
PHP队列用法实例
2014/11/05 PHP
PHP保存session到memcache服务器的方法
2016/01/19 PHP
php基于自定义函数记录log日志方法
2017/07/21 PHP
PDO::exec讲解
2019/01/28 PHP
学习YUI.Ext 第七天--关于View&JSONView
2007/03/10 Javascript
JavaScript 对象链式操作测试代码
2010/04/25 Javascript
JS打开新窗口的2种方式
2013/04/18 Javascript
JS中的数组的sort方法使用示例
2014/01/22 Javascript
JS实现固定在右下角可展开收缩DIV层的方法
2015/02/13 Javascript
Wireshark基本介绍和学习TCP三次握手
2016/08/15 Javascript
详解Vue生命周期的示例
2017/03/10 Javascript
原生js的ajax和解决跨域的jsonp(实例讲解)
2017/10/16 Javascript
javascript+css3开发打气球小游戏完整代码
2017/11/28 Javascript
NodeJS搭建HTTP服务器的实现步骤
2018/10/12 NodeJs
ES6的Fetch异步请求的实现方法
2018/12/07 Javascript
JQuery实现简单的复选框树形结构图示例【附源码下载】
2019/07/16 jQuery
node.js中npm包管理工具用法分析
2020/02/14 Javascript
python获取图片颜色信息的方法
2015/03/18 Python
Python 实现淘宝秒杀的示例代码
2018/01/02 Python
使用Python写一个小游戏
2018/04/02 Python
Python使用random.shuffle()打乱列表顺序的方法
2018/11/08 Python
Python openpyxl读取单元格字体颜色过程解析
2019/09/03 Python
使用tensorflow框架在Colab上跑通猫狗识别代码
2020/04/26 Python
Python用户自定义异常的实现
2020/12/25 Python
美国知名的女性服饰品牌:LOFT(洛芙特)
2016/08/05 全球购物
广州品高软件.net笔面试题目
2012/04/18 面试题
如何进行Linux分区优化
2016/09/13 面试题
2013届毕业生求职信范文
2013/11/20 职场文书
运动会开幕式主持词
2014/03/28 职场文书
关于晚自习早退的检讨书
2014/09/13 职场文书
银行开户授权委托书格式
2014/10/10 职场文书
电影建国大业观后感
2015/06/01 职场文书
2015年统计员个人工作总结
2015/07/23 职场文书
2019感恩宣传标语!
2019/07/05 职场文书
Django框架之路由用法
2022/06/10 Python
鸿蒙3.0体验感怎么样? 鸿蒙3.0系统评测向
2022/08/14 数码科技