pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
学习python的几条建议分享
Feb 10 Python
玩转python爬虫之爬取糗事百科段子
Feb 17 Python
python字典的常用操作方法小结
May 16 Python
Python+django实现简单的文件上传
Aug 17 Python
Python中的is和==比较两个对象的两种方法
Sep 06 Python
python用fsolve、leastsq对非线性方程组求解
Dec 15 Python
Python读取csv文件分隔符设置方法
Jan 14 Python
Python 类的魔法属性用法实例分析
Nov 21 Python
python返回数组的索引实例
Nov 28 Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 Python
Pytho爬虫中Requests设置请求头Headers的方法
Sep 22 Python
详解matplotlib中pyplot和面向对象两种绘图模式之间的关系
Jan 22 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
两个开源的Php输出Excel文件类
2010/02/08 PHP
CI(CodeIgniter)框架配置
2014/06/10 PHP
PHP空值检测函数与方法汇总
2017/11/19 PHP
PHP+MySQL实现输入页码跳转到指定页面功能示例
2018/06/01 PHP
laravel orm 关联条件查询代码
2019/10/21 PHP
javascript js cookie的存储,获取和删除
2007/12/29 Javascript
Ext grid 添加右击菜单
2009/11/26 Javascript
Javascript 读书笔记索引贴
2010/01/11 Javascript
jquery多行滚动/向左或向上滚动/响应鼠标实现思路及代码
2013/01/23 Javascript
js函数中onmousedown和onclick的区别和联系探讨
2013/05/19 Javascript
Jquery 类网页微信二维码图块滚动效果具体实现
2013/10/14 Javascript
javascript处理表单示例(javascript提交表单)
2014/04/28 Javascript
javascript白色简洁计算器
2015/05/04 Javascript
AngularJS实用开发技巧(推荐)
2016/07/13 Javascript
遍历json 对象的属性并且动态添加属性的实现
2016/12/02 Javascript
JS中的作用域链
2017/03/01 Javascript
详解Vuex管理登录状态
2017/11/13 Javascript
Angular4 组件通讯方法大全(推荐)
2018/07/12 Javascript
vue实现Excel文件的上传与下载功能的两种方式
2019/06/28 Javascript
详解vuex的简单todolist例子
2019/07/14 Javascript
Vue的生命周期操作示例
2019/09/17 Javascript
[04:50]2019DOTA2高校联赛秋季赛四强集锦
2019/12/27 DOTA
用于统计项目中代码总行数的Python脚本分享
2015/04/21 Python
Python两台电脑实现TCP通信的方法示例
2019/05/06 Python
python openvc 裁剪、剪切图片 提取图片的行和列
2019/09/19 Python
pytorch 模型的train模式与eval模式实例
2020/02/20 Python
Django Form常用功能及代码示例
2020/10/13 Python
python 使用paramiko模块进行封装,远程操作linux主机的示例代码
2020/12/03 Python
中国跨境电子商务网站:NewFrog
2018/03/10 全球购物
2014年国培研修感言
2014/03/09 职场文书
公司寄语大全
2014/04/10 职场文书
公司节能减排倡议书
2014/05/14 职场文书
11.9消防日宣传标语
2014/10/08 职场文书
mysql5.7使用binlog 恢复数据的方法
2021/06/03 MySQL
Vue过滤器(filter)实现及应用场景详解
2021/06/15 Vue.js
深入浅析python3 依赖倒置原则(示例代码)
2021/07/09 Python