pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python多线程编程中的join函数使用心得
Sep 02 Python
Python通过递归遍历出集合中所有元素的方法
Feb 25 Python
Python实现CET查分的方法
Mar 10 Python
python使用matplotlib模块绘制多条折线图、散点图
Apr 26 Python
用Python读取几十万行文本数据
Dec 24 Python
python3 selenium自动化测试 强大的CSS定位方法
Aug 23 Python
基于Python解密仿射密码
Oct 21 Python
Python在终端通过pip安装好包以后在Pycharm中依然无法使用的问题(三种解决方案)
Mar 10 Python
Pyspark获取并处理RDD数据代码实例
Mar 27 Python
Python 如何实现访问者模式
Jul 28 Python
Python Charles抓包配置实现流程图解
Sep 29 Python
Python实现图片指定位置加图片水印(附Pyinstaller打包exe)
Mar 04 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
全国FM电台频率大全 - 27 陕西省
2020/03/11 无线电
我的论坛源代码(九)
2006/10/09 PHP
PHP strtotime函数用法、实现原理和源码分析
2015/02/04 PHP
实例讲解php实现多线程
2019/01/27 PHP
JSON 数字排序多字段排序介绍
2013/09/18 Javascript
js判断字符长度以及中英文数字等
2013/12/31 Javascript
利用Query+bootstrap和js两种方式实现日期选择器
2017/01/10 Javascript
利用CSS、JavaScript及Ajax实现图片预加载的三大方法
2017/01/22 Javascript
jquery 回调操作实例分析【回调成功与回调失败的情况】
2019/09/27 jQuery
Javascript查看大图功能代码实现
2020/05/07 Javascript
在Python的Django框架中获取单个对象数据的简单方法
2015/07/17 Python
Tensorflow卷积神经网络实例
2018/05/24 Python
Python操作mongodb数据库进行模糊查询操作示例
2018/06/09 Python
Python面向对象类的继承实例详解
2018/06/27 Python
python对日志进行处理的实例代码
2018/10/06 Python
Python之循环结构
2019/01/15 Python
基于python及pytorch中乘法的使用详解
2019/12/27 Python
new_zeros() pytorch版本的转换方式
2020/02/18 Python
python GUI库图形界面开发之PyQt5控件QTableWidget详细使用方法与属性
2020/02/25 Python
Python定义一个函数的方法
2020/06/15 Python
利用python绘制中国地图(含省界、河流等)
2020/09/21 Python
使用Python爬取小姐姐图片(beautifulsoup法)
2021/02/11 Python
兰蔻俄罗斯官方网站:Lancome俄罗斯
2019/12/09 全球购物
为数据库创建索引都需要注意些什么
2012/07/17 面试题
网络安全方面的面试题
2016/01/07 面试题
一套英文Java笔试题面试题
2016/04/21 面试题
大学生通用个人自我评价
2014/04/27 职场文书
安全生产月活动总结
2014/05/04 职场文书
公民授权委托书范本
2014/09/17 职场文书
教师党的群众路线对照检查材料
2014/09/24 职场文书
学校领导四风问题整改措施思想汇报
2014/10/09 职场文书
学校2014年度工作总结
2014/12/06 职场文书
2015关于重阳节的演讲稿
2015/03/20 职场文书
劳动仲裁撤诉申请书
2015/05/18 职场文书
redis的list数据类型相关命令介绍及使用
2022/01/18 Redis
Python创建SQL数据库流程逐步讲解
2022/09/23 Python