PyTorch的SoftMax交叉熵损失和梯度用法


Posted in Python onJanuary 15, 2020

在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算

关于softmax_cross_entropy求导的过程,可以参考HERE

示例

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# 对data求梯度, 用于反向传播
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

# 多分类标签 one-hot格式
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

输出:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:损失 和 input's grad:输入的梯度
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

注意

对于单输入的loss 和 grad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#分别令不同索引位置label为1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

其输出:

# 第一组:
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

# 第二组:
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

# 第三组:
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
解释:
对于输入数据 tensor([[ 1., 2., 3.]]) softmax之后的结果如下
tensor([[ 0.0900, 0.2447, 0.6652]])
交叉熵求解梯度推导公式可知 s[0, 0]-1, s[0, 1]-1, s[0, 2]-1 是上面三组label对应的输入数据梯度
"""

pytorch提供的softmax, 和log_softmax 关系

# 官方提供的softmax实现
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: # 等同于 pytorch 提供的 softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax # 计算结果相同
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

# 与log_softmax关系
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])

官方提供的softmax交叉熵求解结果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

# 交叉熵的计算方式之一
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss输入是log_softmax, target是非one-hot格式的label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)

"""
# 交叉熵计算方式二
loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#CrossEntropyLoss适用于分类问题的损失函数
#input:没有softmax过的nn.output, target是非one-hot格式label
loss = loss_fn(data, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)
"""
"""

输出

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

通过和示例的输出对比, 发现两者是一样的

以上这篇PyTorch的SoftMax交叉熵损失和梯度用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用glob和rmtree删除目录子目录及所有文件的例子
Nov 21 Python
python中logging库的使用总结
Oct 18 Python
Python实现随机生成手机号及正则验证手机号的方法
Apr 25 Python
在python image 中安装中文字体的实现方法
Aug 22 Python
python实现获取单向链表倒数第k个结点的值示例
Oct 24 Python
Pytorch在NLP中的简单应用详解
Jan 08 Python
有关Tensorflow梯度下降常用的优化方法分享
Feb 04 Python
Python集合操作方法详解
Feb 09 Python
python datetime处理时间小结
Apr 16 Python
Python selenium模块实现定位过程解析
Jul 09 Python
python 实现压缩和解压缩的示例
Sep 22 Python
Python下使用Trackbar实现绘图板
Oct 27 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 #Python
Python 中@property的用法详解
Jan 15 #Python
Python字符串中删除特定字符的方法
Jan 15 #Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 #Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 #Python
You might like
VML绘图板②脚本--VMLgraph.js、XMLtool.js
2006/10/09 PHP
php操作XML、读取数据和写入数据的实现代码
2014/08/15 PHP
PHP观察者模式实例分析【对比JS观察者模式】
2019/05/22 PHP
javascript 图片上传预览-兼容标准
2009/06/01 Javascript
JavaScript 反科里化 this [译]
2012/09/20 Javascript
减少访问DOM的次数提升javascript性能
2014/02/24 Javascript
原生javascript实现简单的datagrid数据表格
2015/01/02 Javascript
jquery实现select下拉框美化特效代码分享
2015/08/18 Javascript
JS实现带提示的星级评分效果完整实例
2015/10/30 Javascript
AngularJS表达式讲解及示例代码
2016/08/16 Javascript
JavaScript中object和Object的区别(详解)
2017/02/27 Javascript
JavaScript实现的搜索及高亮显示功能示例
2017/08/14 Javascript
使用vue-router切换页面时,获取上一页url以及当前页面url的方法
2019/05/06 Javascript
微信小程序云开发之云函数详解
2019/05/16 Javascript
create-react-app中添加less支持的实现
2019/11/15 Javascript
JS造成内存泄漏的几种情况实例分析
2020/03/02 Javascript
vue实现简单的登录弹出框
2020/10/26 Javascript
JavaScript TAB栏切换效果的示例
2020/11/05 Javascript
Python数据结构之双向链表的定义与使用方法示例
2018/01/16 Python
Python Flask前后端Ajax交互的方法示例
2018/07/31 Python
Django中如何使用sass的方法步骤
2019/07/09 Python
ubuntu上安装python的实例方法
2019/09/30 Python
python常用排序算法的实现代码
2019/11/08 Python
Python numpy.zero() 初始化矩阵实例
2019/11/27 Python
Python+Redis实现布隆过滤器
2019/12/08 Python
利用python批量爬取百度任意类别的图片的实现方法
2020/10/07 Python
美国波道夫·古德曼百货官网:Bergdorf Goodman
2017/11/07 全球购物
goodhealth官方海外旗舰店:新西兰国民营养师
2017/12/15 全球购物
Topshop美国官网:英国快速时尚品牌
2019/05/16 全球购物
日本订房网站,预订日本星级酒店/温泉旅馆:Relux(支持中文)
2020/01/03 全球购物
电子商务自荐书范文
2014/01/04 职场文书
民主生活会对照检查材料
2014/09/22 职场文书
高三数学教学反思
2016/02/18 职场文书
深入解析NumPy中的Broadcasting广播机制
2021/05/30 Python
MYSQL优化之数据表碎片整理详解
2022/04/03 MySQL
python如何为list实现find方法
2022/05/30 Python