pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中表达式i += x与i = i + x是否等价
Feb 08 Python
python3 模拟登录v2ex实例讲解
Jul 13 Python
深入理解python中函数传递参数是值传递还是引用传递
Nov 07 Python
遗传算法python版
Mar 19 Python
Python实现的多叉树寻找最短路径算法示例
Jul 30 Python
使用python的pexpect模块,实现远程免密登录的示例
Feb 14 Python
Python+OpenCV 实现图片无损旋转90°且无黑边
Dec 12 Python
基于Django实现日志记录报错信息
Dec 17 Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 Python
基于Python爬虫采集天气网实时信息
Jun 05 Python
Elasticsearch py客户端库安装及使用方法解析
Sep 14 Python
Python为何不支持switch语句原理详解
Oct 21 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
用PHP将数据导入到Foxmail的实现代码
2010/09/05 PHP
php ajax 静态分页过程形式
2011/09/02 PHP
JoshChen_php新手进阶高手不可或缺的规范介绍
2013/08/16 PHP
Smarty变量调节器失效的解决办法
2014/08/20 PHP
php类自动加载器实现方法
2015/07/28 PHP
PHP实现上传多文件示例代码
2017/02/20 PHP
php获取微信基础接口凭证Access_token
2018/08/23 PHP
laravel-admin表单提交隐藏一些数据,回调时获取数据的方法
2019/10/08 PHP
jQuery 1.0.4 - New Wave Javascript(js源文件)
2007/01/15 Javascript
javascript与CSS复习(《精通javascript》)
2010/06/29 Javascript
js中的string.format函数代码
2020/08/11 Javascript
jQuery图片预加载 等比缩放实现代码
2011/10/04 Javascript
利用浏览器全屏api实现js全屏
2014/01/16 Javascript
JavaScript实现列出数组中最长的连续数
2014/12/29 Javascript
angularjs 处理多个异步请求方法汇总
2015/01/06 Javascript
JQuery实现的按钮倒计时效果
2015/12/23 Javascript
jquery正则表达式验证(手机号、身份证号、中文名称)
2015/12/31 Javascript
JavaScript获取select中text值的方法
2017/02/13 Javascript
Angular.js中处理页面闪烁的方法详解
2017/03/09 Javascript
js canvas实现放大镜查看图片功能
2017/06/08 Javascript
详解基于 axios 的 Vue 项目 http 请求优化
2017/09/04 Javascript
一次记住JavaScript的6个正则表达式方法
2018/02/22 Javascript
vue生成token并保存到本地存储中
2018/07/17 Javascript
Mint UI实现A-Z字母排序的城市选择列表
2018/12/28 Javascript
JavaScript HTML DOM元素 节点操作汇总
2019/07/29 Javascript
解决Vue大括号字符换行踩的坑
2020/11/09 Javascript
Python计时相关操作详解【time,datetime】
2017/05/26 Python
Python爬虫之Selenium实现关闭浏览器
2020/12/04 Python
MYPROTEIN澳大利亚官方网站:欧洲运动营养品牌
2019/06/26 全球购物
材料采购员岗位职责
2013/12/17 职场文书
辩论赛主持词
2014/03/18 职场文书
2015年审计人员工作总结
2015/05/26 职场文书
劳动合同变更协议书范本
2019/04/18 职场文书
用基于python的appium爬取b站直播消费记录
2021/04/17 Python
健身房被搭讪?用python写了个小米计时器助人为乐
2021/06/08 Python
死磕 java同步系列之synchronized解析
2021/06/28 Java/Android