pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单介绍Python中的JSON使用
Apr 28 Python
简介二分查找算法与相关的Python实现示例
Aug 26 Python
Python实现图片尺寸缩放脚本
Mar 10 Python
关于Python正则表达式 findall函数问题详解
Mar 22 Python
python opencv实现运动检测
Jul 10 Python
python中import与from方法总结(推荐)
Mar 21 Python
Python面向对象总结及类与正则表达式详解
Apr 18 Python
浅析Python3中的对象垃圾收集机制
Jun 06 Python
python requests库爬取豆瓣电视剧数据并保存到本地详解
Aug 10 Python
浅析Python 字符编码与文件处理
Sep 24 Python
pyspark对Mysql数据库进行读写的实现
Dec 30 Python
pycharm进入时每次都是insert模式的解决方式
Feb 05 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
PHP定时执行计划任务的多种方法小结
2011/12/19 PHP
php使用PDO执行SQL语句的方法分析
2017/02/16 PHP
PHP实现的Redis多库选择功能单例类
2017/07/27 PHP
解决laravel资源加载路径设置的问题
2019/10/14 PHP
页面中body onload 和 window.onload 冲突的问题的解决
2009/07/01 Javascript
javascript判断两个IP地址是否在同一个网段的实现思路
2013/12/13 Javascript
JavaScript操作Oracle数据库示例
2015/03/06 Javascript
Jquery插件实现点击获取验证码后60秒内禁止重新获取
2015/03/13 Javascript
js焦点文字滚动效果代码分享
2015/08/25 Javascript
轻松掌握JavaScript享元模式
2016/08/27 Javascript
微信小程序开发之Tabbar实例详解
2017/01/09 Javascript
深入理解node.js之path模块
2017/05/03 Javascript
将 vue 生成的 js 上传到七牛的实例
2017/07/28 Javascript
JavaScript实现图片拖曳效果
2017/09/08 Javascript
JavaScript EventEmitter 背后的秘密 完整版
2018/03/29 Javascript
原生JS实现的碰撞检测功能示例
2018/05/18 Javascript
微信小程序使用wxParse解析html的实现示例
2018/08/30 Javascript
详解vue路由
2020/08/05 Javascript
[54:56]DOTA2上海特级锦标赛主赛事日 - 5 总决赛Liquid VS Secret第三局
2016/03/06 DOTA
在Python中使用turtle绘制多个同心圆示例
2019/11/23 Python
Pytorch技巧:DataLoader的collate_fn参数使用详解
2020/01/08 Python
JetBrains PyCharm(Community版本)的下载、安装和初步使用图文教程详解
2020/03/19 Python
Python yield生成器和return对比代码实例
2020/04/20 Python
python实现简单遗传算法
2020/09/18 Python
HTML5 拖放(Drag 和 Drop)详解与实例代码
2017/09/14 HTML / CSS
html5简介及新增功能介绍
2020/05/18 HTML / CSS
世界上最大的罕见唱片、CD和音乐纪念品网上商店:991.com
2018/05/03 全球购物
英国书籍、CD、DVD和游戏的第一道德零售商:Awesome Books
2020/02/22 全球购物
应用化学专业职业生涯规划书
2013/12/31 职场文书
实习心得体会
2014/01/02 职场文书
网站创业计划书
2014/04/30 职场文书
2014办公室副主任四风对照检查材料思想汇报
2014/09/20 职场文书
导游词范文
2015/02/13 职场文书
地雷战观后感
2015/06/09 职场文书
2015小学毕业班工作总结
2015/07/21 职场文书
公司趣味运动会开幕词
2016/03/04 职场文书