pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python警察与小偷的实现之一客户端与服务端通信实例
Oct 09 Python
详解Python中的array数组模块相关使用
Jul 05 Python
利用python实现简单的循环购物车功能示例代码
Jul 05 Python
理解python中生成器用法
Dec 20 Python
Python实现自定义函数的5种常见形式分析
Jun 16 Python
Python 十六进制整数与ASCii编码字符串相互转换方法
Jul 09 Python
python散点图实例之随机漫步
Aug 27 Python
Python基于opencv实现的简单画板功能示例
Mar 04 Python
详解Python数据分析--Pandas知识点
Mar 23 Python
对Python _取log的几种方式小结
Jul 25 Python
MATLAB数学建模之画图汇总
Jul 16 Python
python3美化表格数据输出结果的实现代码
Apr 14 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
PHP保存带BOM文件的方法
2015/02/12 PHP
php 使用html5实现多文件上传实例
2016/10/24 PHP
PHP对象、模式与实践之高级特性分析
2016/12/08 PHP
如何实现iframe(嵌入式帧)的自适应高度
2006/07/26 Javascript
js中设置元素class的三种方法小结
2011/08/28 Javascript
跨域请求之jQuery的ajax jsonp的使用解惑
2011/10/09 Javascript
window.navigate 与 window.location.href 的使用区别介绍
2013/09/21 Javascript
使用JavaScript根据图片获取条形码的方法
2017/07/04 Javascript
webpack学习笔记之代码分割和按需加载的实例详解
2017/07/20 Javascript
jQuery实现切换隐藏与显示同时切换图标功能
2017/10/29 jQuery
nodeJS服务器的创建和重新启动的实现方法
2018/05/12 NodeJs
angularjs中判断ng-repeat是否迭代完的实例
2018/09/12 Javascript
[18:16]sakonoko 2017年卡尔集锦
2018/02/06 DOTA
[02:23]1个至宝=115个英雄特效 最“绿”至宝拉比克“魔导师密钥”登场
2018/12/29 DOTA
python检测远程端口是否打开的方法
2015/03/14 Python
在Debian下配置Python+Django+Nginx+uWSGI+MySQL的教程
2015/04/25 Python
详解python之简单主机批量管理工具
2017/01/27 Python
Tensorflow 同时载入多个模型的实例讲解
2018/07/27 Python
python实现C4.5决策树算法
2018/08/29 Python
更新pip3与pyttsx3文字语音转换的实现方法
2019/08/08 Python
关于tf.nn.dynamic_rnn返回值详解
2020/01/20 Python
Python如何使用bokeh包和geojson数据绘制地图
2020/03/21 Python
CSS3教程(5):网页背景图片
2009/04/02 HTML / CSS
芝加哥牛排公司:Chicago Steak Company
2018/10/31 全球购物
档案检查欢迎词
2014/01/13 职场文书
致100米运动员广播稿
2014/02/14 职场文书
2014年党员公开承诺践诺书
2014/03/25 职场文书
社会实践活动总结报告
2014/04/29 职场文书
环保口号大全
2014/06/12 职场文书
幼儿园端午节活动总结
2015/05/05 职场文书
2015年小学一年级班主任工作总结
2015/05/21 职场文书
拉贝日记观后感
2015/06/05 职场文书
呼啸山庄读书笔记
2015/06/29 职场文书
导游词之宿迁乾隆行宫
2019/10/15 职场文书
Python打包为exe详细教程
2021/05/18 Python
Django集成富文本编辑器summernote的实现步骤
2021/05/31 Python