PyTorch的SoftMax交叉熵损失和梯度用法


Posted in Python onJanuary 15, 2020

在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算

关于softmax_cross_entropy求导的过程,可以参考HERE

示例

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# 对data求梯度, 用于反向传播
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

# 多分类标签 one-hot格式
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

输出:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:损失 和 input's grad:输入的梯度
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

注意

对于单输入的loss 和 grad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#分别令不同索引位置label为1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

其输出:

# 第一组:
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

# 第二组:
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

# 第三组:
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
解释:
对于输入数据 tensor([[ 1., 2., 3.]]) softmax之后的结果如下
tensor([[ 0.0900, 0.2447, 0.6652]])
交叉熵求解梯度推导公式可知 s[0, 0]-1, s[0, 1]-1, s[0, 2]-1 是上面三组label对应的输入数据梯度
"""

pytorch提供的softmax, 和log_softmax 关系

# 官方提供的softmax实现
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: # 等同于 pytorch 提供的 softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax # 计算结果相同
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

# 与log_softmax关系
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])

官方提供的softmax交叉熵求解结果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

# 交叉熵的计算方式之一
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss输入是log_softmax, target是非one-hot格式的label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)

"""
# 交叉熵计算方式二
loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#CrossEntropyLoss适用于分类问题的损失函数
#input:没有softmax过的nn.output, target是非one-hot格式label
loss = loss_fn(data, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)
"""
"""

输出

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

通过和示例的输出对比, 发现两者是一样的

以上这篇PyTorch的SoftMax交叉熵损失和梯度用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的登录和操作开心网脚本分享
Jul 09 Python
Python 不同对象比较大小示例探讨
Aug 21 Python
python使用点操作符访问字典(dict)数据的方法
Mar 16 Python
python简单实现基数排序算法
May 16 Python
linux下python抓屏实现方法
May 22 Python
pandas 实现字典转换成DataFrame的方法
Jul 04 Python
Python中的支持向量机SVM的使用(附实例代码)
Jun 26 Python
Python迭代器iterator生成器generator使用解析
Oct 24 Python
Python Django2.0集成Celery4.1教程
Nov 19 Python
python读取图像矩阵文件并转换为向量实例
Jun 18 Python
python collections模块的使用
Oct 16 Python
Python Django项目和应用的创建详解
Nov 27 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 #Python
Python 中@property的用法详解
Jan 15 #Python
Python字符串中删除特定字符的方法
Jan 15 #Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 #Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 #Python
You might like
php自动适应范围的分页代码
2008/08/05 PHP
php模拟登陆的实现方法分析
2015/01/09 PHP
innertext , insertadjacentelement , insertadjacenthtml , insertadjacenttext 等区别
2007/06/29 Javascript
jquery animate图片模向滑动示例代码
2011/01/26 Javascript
Array.prototype.concat不是通用方法反驳[译]
2012/09/20 Javascript
JS简单限制textarea内输入字符数量的方法
2015/10/14 Javascript
jquery利用拖拽方式在图片上添加热链接
2015/11/24 Javascript
神奇!js+CSS+DIV实现文字颜色渐变效果
2016/03/16 Javascript
jQuery动态改变多行文本框高度的方法
2016/09/07 Javascript
JS运动特效之完美运动框架实例分析
2018/01/24 Javascript
mpvue中使用flyjs全局拦截的实现代码
2018/09/13 Javascript
JavaScript继承的特性与实践应用深入详解
2018/12/30 Javascript
Vue起步(无cli)的啊教程详解
2019/04/11 Javascript
axios 实现post请求时把对象obj数据转为formdata
2019/10/31 Javascript
Vue2.4+新增属性.sync、$attrs、$listeners的具体使用
2020/03/08 Javascript
[01:25]2015国际邀请赛最佳短片奖——斧王《拆塔英雄:天赋异禀》
2015/09/22 DOTA
python爬虫入门教程之糗百图片爬虫代码分享
2014/09/02 Python
Python正则表达式和re库知识点总结
2019/02/11 Python
python使用tomorrow实现多线程的例子
2019/07/20 Python
python采集百度搜索结果带有特定URL的链接代码实例
2019/08/30 Python
python编写猜数字小游戏
2019/10/06 Python
解决 jupyter notebook 回车换两行问题
2020/04/15 Python
Python Json数据文件操作原理解析
2020/05/09 Python
python中get和post有什么区别
2020/06/19 Python
python 通过exifread读取照片信息
2020/12/24 Python
浅谈CSS3中display属性的Flex布局的方法
2017/08/14 HTML / CSS
NUK奶瓶美国官网:NUK美国
2016/09/26 全球购物
意大利单身交友网站:Meetic
2020/07/12 全球购物
优秀医生事迹材料
2014/02/12 职场文书
代理协议书
2014/04/22 职场文书
学校爱国卫生月活动总结
2014/06/25 职场文书
普通党员四风问题对照检查材料
2014/09/27 职场文书
收入及婚姻状况证明
2014/11/20 职场文书
初中学生操行评语
2014/12/26 职场文书
淘宝客服专员岗位职责
2015/04/07 职场文书
国际贸易实训总结
2015/08/03 职场文书