PyTorch的SoftMax交叉熵损失和梯度用法


Posted in Python onJanuary 15, 2020

在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算

关于softmax_cross_entropy求导的过程,可以参考HERE

示例

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# 对data求梯度, 用于反向传播
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

# 多分类标签 one-hot格式
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

输出:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:损失 和 input's grad:输入的梯度
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

注意

对于单输入的loss 和 grad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#分别令不同索引位置label为1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

其输出:

# 第一组:
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

# 第二组:
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

# 第三组:
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
解释:
对于输入数据 tensor([[ 1., 2., 3.]]) softmax之后的结果如下
tensor([[ 0.0900, 0.2447, 0.6652]])
交叉熵求解梯度推导公式可知 s[0, 0]-1, s[0, 1]-1, s[0, 2]-1 是上面三组label对应的输入数据梯度
"""

pytorch提供的softmax, 和log_softmax 关系

# 官方提供的softmax实现
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: # 等同于 pytorch 提供的 softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax # 计算结果相同
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

# 与log_softmax关系
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])

官方提供的softmax交叉熵求解结果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

# 交叉熵的计算方式之一
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss输入是log_softmax, target是非one-hot格式的label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)

"""
# 交叉熵计算方式二
loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#CrossEntropyLoss适用于分类问题的损失函数
#input:没有softmax过的nn.output, target是非one-hot格式label
loss = loss_fn(data, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)
"""
"""

输出

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

通过和示例的输出对比, 发现两者是一样的

以上这篇PyTorch的SoftMax交叉熵损失和梯度用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python命令行参数解析OptionParser类用法实例
Oct 09 Python
浅谈Python实现贪心算法与活动安排问题
Dec 19 Python
Python标准库笔记struct模块的使用
Feb 22 Python
python实现批量修改图片格式和尺寸
Jun 07 Python
在Python中通过getattr获取对象引用的方法
Jan 21 Python
python钉钉机器人运维脚本监控实例
Feb 20 Python
利用python和百度地图API实现数据地图标注的方法
May 13 Python
python可视化篇之流式数据监控的实现
Aug 07 Python
PYTHON如何读取和写入EXCEL里面的数据
Oct 28 Python
python TK库简单应用(实时显示子进程输出)
Oct 29 Python
Python Pandas 对列/行进行选择,增加,删除操作
May 17 Python
Python实战之用tkinter库做一个鼠标模拟点击器
Apr 27 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 #Python
Python 中@property的用法详解
Jan 15 #Python
Python字符串中删除特定字符的方法
Jan 15 #Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 #Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 #Python
You might like
excellent!――ASCII Art(由目标图象生成ascii)
2007/02/20 PHP
PHP限制页面只能在微信自带浏览器访问的代码
2014/01/15 PHP
PHP防止刷新重复提交页面的示例代码
2015/11/11 PHP
php集成动态口令认证
2016/07/21 PHP
PHP实现带重试功能的curl连接示例
2016/07/28 PHP
JS 页面自动加载函数(兼容多浏览器)
2009/05/18 Javascript
JavaScript Event学习第七章 事件属性
2010/02/07 Javascript
深入理解javascript中return的作用
2013/12/30 Javascript
JS表格组件神器bootstrap table详解(基础版)
2015/12/08 Javascript
JS组件系列之Bootstrap table表格组件神器【终结篇】
2016/05/10 Javascript
简单的分页代码js实现
2016/05/17 Javascript
使用React实现轮播效果组件示例代码
2016/09/05 Javascript
JS创建对象的写法示例
2016/11/04 Javascript
JavaScript闭包原理与用法实例分析
2018/08/10 Javascript
vuejs前后端数据交互之从后端请求数据的实例
2018/08/11 Javascript
浅谈Vue.use的使用
2018/08/29 Javascript
jQuery.validate.js表单验证插件的使用代码详解
2018/10/22 jQuery
JS实现的类似微信聊天效果示例
2019/01/29 Javascript
JavaScript浅层克隆与深度克隆示例详解
2020/09/01 Javascript
[45:25]OG vs EG 2019国际邀请赛淘汰赛 胜者组 BO3 第一场 8.22
2019/09/05 DOTA
基于python3 类的属性、方法、封装、继承实例讲解
2017/09/19 Python
Python hashlib模块的使用示例
2020/10/09 Python
css3圆角边框和边框阴影示例
2014/05/05 HTML / CSS
关于css兼容性问题及一些常见问题汇总
2016/05/03 HTML / CSS
CSS3实现可翻转的hover效果
2018/05/23 HTML / CSS
Agoda.com官方网站:便宜预订全球酒店,高达80%的折扣
2018/04/04 全球购物
馥蕾诗美国官网:Fresh美国
2019/10/09 全球购物
初始化了一个没有run()方法的线程类,是否会出错?
2014/03/27 面试题
服装厂厂长职责
2013/12/16 职场文书
大学毕业生推荐信
2014/07/09 职场文书
公司委托书格式范本
2014/09/16 职场文书
小学语文复习计划
2015/01/19 职场文书
工作时间证明
2015/06/15 职场文书
关于开学的感想
2015/08/10 职场文书
HAM-2000摩机图
2021/04/22 无线电
给numpy.array增加维度的超简单方法
2021/06/02 Python