pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用正则表达式检测密码强度源码分享
Jun 11 Python
Python version 2.7 required, which was not found in the registry
Aug 26 Python
Python中用altzone()方法处理时区的教程
May 22 Python
python3安装pip3(install pip3 for python 3.x)
Apr 03 Python
python 列表中[ ]中冒号‘:’的作用
Apr 30 Python
将python文件打包成EXE应用程序的方法
May 22 Python
Python3离线安装Requests模块问题
Oct 13 Python
详解python中docx库的安装过程
Nov 08 Python
使用jupyter Nodebook查看函数或方法的参数以及使用情况
Apr 14 Python
Python判断远程服务器上Excel文件是否被人打开的方法
Jul 13 Python
详解Python中openpyxl模块基本用法
Feb 23 Python
Pytest中conftest.py的用法
Jun 27 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
自己动手,丰衣足食 - 短波框形天线制作
2021/03/01 无线电
解析php利用正则表达式解决采集内容排版的问题
2013/06/20 PHP
php验证码的制作思路和实现方法
2015/11/12 PHP
php上传图片生成缩略图(GD库)
2016/01/06 PHP
crontab无法执行php的解决方法
2016/01/25 PHP
CentOS 7.2 下编译安装PHP7.0.10+MySQL5.7.14+Nginx1.10.1的方法详解(mini版本)
2016/09/01 PHP
详解PHP中array_rand函数的使用方法
2016/09/11 PHP
php使用函数pathinfo()、parse_url()和basename()解析URL
2016/11/25 PHP
Yii2学习笔记之汉化yii设置表单的描述(属性标签attributeLabels)
2017/02/07 PHP
PHP 布尔值的自增与自减的实现方法
2018/05/03 PHP
仿迅雷焦点广告效果(JQuery版)
2008/11/19 Javascript
IE Firefox 使用自定义标签的区别
2009/10/15 Javascript
jQuery EasyUI API 中文文档 - DataGrid数据表格
2011/11/17 Javascript
jquery插件制作简单示例说明
2012/02/03 Javascript
js简单实现表单中点击按钮动态增加输入框数量的方法
2015/08/18 Javascript
JS实现漂亮的淡蓝色滑动门效果代码
2015/09/23 Javascript
BootStrap文件上传样式超好看【持续更新】
2016/05/10 Javascript
JQuery 设置checkbox值二次无效的解决方法
2016/07/22 Javascript
详解前后端分离之VueJS前端
2017/05/24 Javascript
Vue自定义指令详解
2017/07/28 Javascript
webpack 4.0.0-beta.0版本新特性介绍
2018/02/10 Javascript
jQuery实现的简单歌词滚动功能示例
2019/01/07 jQuery
javascript网页随机点名实现过程解析
2019/10/15 Javascript
Python实现控制台输入密码的方法
2015/05/29 Python
Windows下Anaconda的安装和简单使用方法
2018/01/04 Python
判断Threading.start新线程是否执行完毕的实例
2020/05/02 Python
python实现批量命名照片
2020/06/18 Python
利用Python实现自动扫雷小脚本
2020/12/17 Python
某公司Java工程师面试题笔试题
2016/03/27 面试题
日语专业推荐信
2013/11/12 职场文书
个人作风纪律整顿整改措施
2014/10/25 职场文书
班主任2015新年寄语
2014/12/08 职场文书
会计试用期自我评价
2015/03/10 职场文书
追悼会悼词大全
2015/06/23 职场文书
PyMongo 查询数据的实现
2021/06/28 Python
Redis基本数据类型Set常用操作命令
2022/06/01 Redis