PyTorch的SoftMax交叉熵损失和梯度用法


Posted in Python onJanuary 15, 2020

在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算

关于softmax_cross_entropy求导的过程,可以参考HERE

示例

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# 对data求梯度, 用于反向传播
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

# 多分类标签 one-hot格式
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

输出:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:损失 和 input's grad:输入的梯度
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

注意

对于单输入的loss 和 grad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#分别令不同索引位置label为1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

其输出:

# 第一组:
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

# 第二组:
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

# 第三组:
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
解释:
对于输入数据 tensor([[ 1., 2., 3.]]) softmax之后的结果如下
tensor([[ 0.0900, 0.2447, 0.6652]])
交叉熵求解梯度推导公式可知 s[0, 0]-1, s[0, 1]-1, s[0, 2]-1 是上面三组label对应的输入数据梯度
"""

pytorch提供的softmax, 和log_softmax 关系

# 官方提供的softmax实现
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: # 等同于 pytorch 提供的 softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax # 计算结果相同
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

# 与log_softmax关系
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])

官方提供的softmax交叉熵求解结果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

# 交叉熵的计算方式之一
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss输入是log_softmax, target是非one-hot格式的label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)

"""
# 交叉熵计算方式二
loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#CrossEntropyLoss适用于分类问题的损失函数
#input:没有softmax过的nn.output, target是非one-hot格式label
loss = loss_fn(data, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)
"""
"""

输出

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

通过和示例的输出对比, 发现两者是一样的

以上这篇PyTorch的SoftMax交叉熵损失和梯度用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Mysql数据库操作 Perl操作Mysql数据库
Jan 12 Python
python实现根据窗口标题调用窗口的方法
Mar 13 Python
Python里disconnect UDP套接字的方法
Apr 23 Python
深入解析Python中的lambda表达式的用法
Aug 28 Python
python爬虫实现教程转换成 PDF 电子书
Feb 19 Python
Python中enumerate函数代码解析
Oct 31 Python
python3解析库lxml的安装与基本使用
Jun 27 Python
详解python中的time和datetime的常用方法
Jul 08 Python
对python 中re.sub,replace(),strip()的区别详解
Jul 22 Python
Django发送邮件功能实例详解
Sep 02 Python
Python如何解除一个装饰器
Aug 07 Python
pycharm安装深度学习pytorch的d2l包失败问题解决
Mar 25 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 #Python
Python 中@property的用法详解
Jan 15 #Python
Python字符串中删除特定字符的方法
Jan 15 #Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 #Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 #Python
You might like
星际RPG字典
2020/03/04 星际争霸
PHP获取http请求的头信息实现步骤
2012/12/16 PHP
如何利用PHP执行.SQL文件
2013/07/05 PHP
ThinkPHP实现支付宝接口功能实例
2014/12/02 PHP
ecshop 2.72如何修改后台访问地址
2015/03/03 PHP
PHP实现多文件上传的方法
2015/07/08 PHP
php简单实现文件或图片强制下载的方法
2016/12/06 PHP
Laravel如何使用Redis共享Session
2018/02/23 PHP
PHP7 其他修改
2021/03/09 PHP
利用WebBrowser彻底解决Web打印问题(包括后台打印)
2009/06/22 Javascript
基于jquery的鼠标拖动效果代码
2012/05/30 Javascript
jQuery实现表头固定效果的实例代码
2013/05/24 Javascript
详解nodejs中exports和module.exports的区别
2017/02/17 NodeJs
Vue中fragment.js使用方法详解
2017/03/09 Javascript
JS实现的JSON数组去重算法示例
2018/04/11 Javascript
浅谈webpack 构建性能优化策略小结
2018/06/13 Javascript
微信小程序实现手指触摸画板
2018/07/09 Javascript
微信小程序中使用ECharts 异步加载数据实现图表功能
2018/07/13 Javascript
VueJs里利用CryptoJs实现加密及解密的方法示例
2019/04/29 Javascript
layer扩展打开/关闭动画的方法
2019/09/23 Javascript
Vue.js组件使用props传递数据的方法
2019/10/19 Javascript
vue中的mescroll搜索运用及各种填坑处理
2019/10/30 Javascript
你可能从未使用过的11+个JavaScript特性(小结)
2020/01/08 Javascript
JavaScript设计模型Iterator实例解析
2020/01/22 Javascript
[01:22]DOTA2神秘商店携大量周边降临完美大师赛
2017/11/07 DOTA
[52:05]EG vs OG 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/18 DOTA
python多线程方式执行多个bat代码
2016/06/07 Python
python 调用win32pai 操作cmd的方法
2017/05/28 Python
Python字符串拼接的几种方法整理
2017/08/02 Python
python中字符串的操作方法大全
2018/06/03 Python
通过Python扫描代码关键字并进行预警的实现方法
2020/05/24 Python
荷兰家电销售网站:Welhof
2020/12/08 全球购物
小班开学寄语
2014/04/04 职场文书
高三励志标语
2014/06/05 职场文书
甜品店创业计划书
2014/08/14 职场文书
先进人物事迹材料
2014/12/29 职场文书