pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量生成本地ip地址的方法
Mar 23 Python
Python的Django框架中模板碎片缓存简介
Jul 24 Python
python中map()与zip()操作方法
Feb 27 Python
Python模糊查询本地文件夹去除文件后缀的实例(7行代码)
Nov 09 Python
Python中property属性实例解析
Feb 10 Python
对numpy中数组元素的统一赋值实例
Apr 04 Python
python:接口间数据传递与调用方法
Dec 17 Python
在Python中关于使用os模块遍历目录的实现方法
Jan 03 Python
python面向对象实现名片管理系统文件版
Apr 26 Python
python numpy 常用随机数的产生方法的实现
Aug 21 Python
pytorch查看通道数 维数 尺寸大小方式
May 26 Python
Python Flask框架实现简单加法工具过程解析
Jun 03 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
利用yahoo汇率接口实现实时汇率转换示例 汇率转换器
2014/01/14 PHP
使用PHP实现阻止用户上传成人照片或者裸照
2014/12/25 PHP
PHP 年月日的三级联动实例代码
2017/05/24 PHP
Javascript客户端脚本的设计和应用
2006/08/21 Javascript
用原生JS对AJAX做简单封装的实例代码
2016/07/13 Javascript
解决Vue页面固定滚动位置的处理办法
2017/07/13 Javascript
js实现数组和对象的深浅拷贝
2017/09/30 Javascript
JS实现遍历不规则多维数组的方法
2018/03/21 Javascript
angularJS实现不同视图同步刷新详解
2018/10/09 Javascript
利用Vue-draggable组件实现Vue项目中表格内容的拖拽排序
2019/06/07 Javascript
vue 解决遍历对象显示的顺序不对问题
2019/11/07 Javascript
JavaScript switch语句使用方法简介
2019/12/30 Javascript
通过javascript实现扫雷游戏代码实例
2020/02/09 Javascript
JS实现拖动模糊框特效
2020/08/25 Javascript
vue 动态创建组件的两种方法
2020/12/31 Vue.js
Python的randrange()方法使用教程
2015/05/15 Python
深入理解python try异常处理机制
2016/06/01 Python
python jieba分词并统计词频后输出结果到Excel和txt文档方法
2018/02/11 Python
python实现事件驱动
2018/11/21 Python
python+selenium select下拉选择框定位处理方法
2019/08/24 Python
python3实现从kafka获取数据,并解析为json格式,写入到mysql中
2019/12/23 Python
Python进程的通信Queue、Pipe实例分析
2020/03/30 Python
python图片指定区域替换img.paste函数的使用
2020/04/09 Python
详解numpy.ndarray.reshape()函数的参数问题
2020/10/13 Python
运行Python编写的程序方法实例
2020/10/21 Python
详解Python中openpyxl模块基本用法
2021/02/23 Python
澳大利亚拥有最佳跳伞降落点和最好服务的跳伞项目运营商:Skydive Australia
2018/03/05 全球购物
For Art’s Sake官网:手工制作的奢华眼镜
2018/12/15 全球购物
迎八一活动主题
2014/01/31 职场文书
十八届三中全会报告学习材料
2014/02/17 职场文书
农村葬礼主持词
2014/03/31 职场文书
党员干部观看《周恩来四个昼夜》思想汇报
2014/09/10 职场文书
领导干部作风建设自查报告
2014/10/23 职场文书
申报材料格式
2014/12/30 职场文书
2015年度个人业务工作总结
2015/04/27 职场文书
关于公司年会的开幕词
2016/03/04 职场文书