PyTorch的SoftMax交叉熵损失和梯度用法


Posted in Python onJanuary 15, 2020

在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算

关于softmax_cross_entropy求导的过程,可以参考HERE

示例

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# 对data求梯度, 用于反向传播
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

# 多分类标签 one-hot格式
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

输出:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:损失 和 input's grad:输入的梯度
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

注意

对于单输入的loss 和 grad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#分别令不同索引位置label为1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

其输出:

# 第一组:
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

# 第二组:
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

# 第三组:
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
解释:
对于输入数据 tensor([[ 1., 2., 3.]]) softmax之后的结果如下
tensor([[ 0.0900, 0.2447, 0.6652]])
交叉熵求解梯度推导公式可知 s[0, 0]-1, s[0, 1]-1, s[0, 2]-1 是上面三组label对应的输入数据梯度
"""

pytorch提供的softmax, 和log_softmax 关系

# 官方提供的softmax实现
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: # 等同于 pytorch 提供的 softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax # 计算结果相同
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

# 与log_softmax关系
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])

官方提供的softmax交叉熵求解结果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

# 交叉熵的计算方式之一
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss输入是log_softmax, target是非one-hot格式的label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)

"""
# 交叉熵计算方式二
loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#CrossEntropyLoss适用于分类问题的损失函数
#input:没有softmax过的nn.output, target是非one-hot格式label
loss = loss_fn(data, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)
"""
"""

输出

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

通过和示例的输出对比, 发现两者是一样的

以上这篇PyTorch的SoftMax交叉熵损失和梯度用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django教程笔记之中间件middleware详解
Aug 01 Python
selenium+python自动化测试之页面元素定位
Jan 23 Python
Python开发网站目录扫描器的实现
Feb 21 Python
浅谈python标准库--functools.partial
Mar 13 Python
Python pandas RFM模型应用实例详解
Nov 20 Python
Python单元测试模块doctest的具体使用
Feb 10 Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 Python
python3.7添加dlib模块的方法
Jul 01 Python
使用Python绘制台风轨迹图的示例代码
Sep 21 Python
python3从网络摄像机解析mjpeg http流的示例
Nov 13 Python
python自动化办公操作PPT的实现
Feb 05 Python
详解Python常用的魔法方法
Jun 03 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 #Python
Python 中@property的用法详解
Jan 15 #Python
Python字符串中删除特定字符的方法
Jan 15 #Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 #Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 #Python
You might like
PHP的范围解析操作符(::)的含义分析说明
2011/07/03 PHP
PHP函数之日期时间函数date()使用详解
2013/09/09 PHP
php通过exif_read_data函数获取图片的exif信息
2015/05/21 PHP
yii gridview实现时间段筛选功能
2017/08/15 PHP
php生成静态页面并实现预览功能
2019/06/27 PHP
flash调用js中的方法,让js传递变量给flash的办法及思路
2013/08/07 Javascript
node.js入门教程迷你书、node.js入门web应用开发完全示例
2014/04/06 Javascript
javascript简单比较日期大小的方法
2016/01/05 Javascript
JS中取二维数组中最大值的方法汇总
2016/04/17 Javascript
javascript验证手机号和实现星号(*)代替实例
2016/08/16 Javascript
JavaScript中全选、全不选、反选、无刷新删除、批量删除、即点即改入库(在yii框架中操作)的代码分享
2016/11/01 Javascript
js实现横向拖拽导航条功能
2017/02/17 Javascript
BootStrap表单宽度设置方法
2017/03/10 Javascript
详解JavaScript添加给定的标签选项
2018/09/17 Javascript
vue进入页面时滚动条始终在底部代码实例
2019/03/26 Javascript
利用webpack理解CommonJS和ES Modules的差异区别
2020/06/16 Javascript
jQuery使用hide()、toggle()函数实现相机品牌展示隐藏功能
2021/01/29 jQuery
用Python写一个无界面的2048小游戏
2016/05/24 Python
一个基于flask的web应用诞生 使用模板引擎和表单插件(2)
2017/04/11 Python
基于python时间处理方法(详解)
2017/08/14 Python
Python基本数据结构与用法详解【列表、元组、集合、字典】
2019/03/23 Python
[机器视觉]使用python自动识别验证码详解
2019/05/16 Python
python内存监控工具memory_profiler和guppy的用法详解
2019/07/29 Python
利用CSS3把图片变成灰色模式的实例代码
2016/09/06 HTML / CSS
如何让pre和textarea等HTML元素去掉滚动条自动换行自适应文本内容高度
2019/08/01 HTML / CSS
美国婚礼和派对礼品网站:Kate Aspen(新娘送礼会、迎婴派对)
2018/03/28 全球购物
在线课程:Skillshare
2019/04/02 全球购物
幼师岗位求职简历的自荐信格式
2013/09/21 职场文书
学生实习介绍信
2014/01/15 职场文书
情况说明书格式范文
2014/05/06 职场文书
英语专业自荐书
2014/06/13 职场文书
法律专业大学生职业生涯规划书:向目标一步步迈进
2014/09/22 职场文书
人大代表选举标语
2014/10/07 职场文书
政协会议宣传标语
2014/10/09 职场文书
王亚平太空授课观后感
2015/06/12 职场文书
高并发下Redis如何保持数据一致性(避免读后写)
2022/03/18 Redis