PyTorch的SoftMax交叉熵损失和梯度用法


Posted in Python onJanuary 15, 2020

在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算

关于softmax_cross_entropy求导的过程,可以参考HERE

示例

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# 对data求梯度, 用于反向传播
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

# 多分类标签 one-hot格式
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

输出:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:损失 和 input's grad:输入的梯度
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

注意

对于单输入的loss 和 grad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#分别令不同索引位置label为1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

其输出:

# 第一组:
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

# 第二组:
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

# 第三组:
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
解释:
对于输入数据 tensor([[ 1., 2., 3.]]) softmax之后的结果如下
tensor([[ 0.0900, 0.2447, 0.6652]])
交叉熵求解梯度推导公式可知 s[0, 0]-1, s[0, 1]-1, s[0, 2]-1 是上面三组label对应的输入数据梯度
"""

pytorch提供的softmax, 和log_softmax 关系

# 官方提供的softmax实现
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: # 等同于 pytorch 提供的 softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax # 计算结果相同
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

# 与log_softmax关系
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])

官方提供的softmax交叉熵求解结果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

# 交叉熵的计算方式之一
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss输入是log_softmax, target是非one-hot格式的label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)

"""
# 交叉熵计算方式二
loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#CrossEntropyLoss适用于分类问题的损失函数
#input:没有softmax过的nn.output, target是非one-hot格式label
loss = loss_fn(data, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)
"""
"""

输出

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

通过和示例的输出对比, 发现两者是一样的

以上这篇PyTorch的SoftMax交叉熵损失和梯度用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python警察与小偷的实现之一客户端与服务端通信实例
Oct 09 Python
Python实现的多线程端口扫描工具分享
Jan 21 Python
Python新手在作用域方面经常容易碰到的问题
Apr 03 Python
在Python的Flask框架中使用日期和时间的教程
Apr 21 Python
浅析Python装饰器以及装饰器模式
May 28 Python
python使用Turtle库绘制动态钟表
Nov 19 Python
Django之Mode的外键自关联和引用未定义的Model方法
Dec 15 Python
使用python telnetlib批量备份交换机配置的方法
Jul 25 Python
django foreignkey外键使用的例子 相当于left join
Aug 06 Python
python2和python3应该学哪个(python3.6与python3.7的选择)
Oct 01 Python
Django ModelForm组件原理及用法详解
Oct 12 Python
Python一行代码实现自动发邮件功能
May 30 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 #Python
Python 中@property的用法详解
Jan 15 #Python
Python字符串中删除特定字符的方法
Jan 15 #Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 #Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 #Python
You might like
PHP调用三种数据库的方法(1)
2006/10/09 PHP
PHP操作mysql函数详解,mysql和php交互函数
2011/05/19 PHP
PHP对象递归引用造成内存泄漏分析
2014/08/28 PHP
php 在字符串指定位置插入新字符的简单实现
2016/06/28 PHP
PHP中empty,isset,is_null用法和区别
2017/02/19 PHP
PHP实现图片的等比缩放和Logo水印功能示例
2017/05/04 PHP
Swoole源码中如何查询Websocket的连接问题详解
2020/08/30 PHP
给网站上的广告“加速”显示的方法
2007/04/08 Javascript
JQuery 网站换肤功能实现代码
2009/11/02 Javascript
jQuery .attr()和.removeAttr()方法操作元素属性示例
2013/07/16 Javascript
使用POST方式弹出窗口的两种方法示例介绍
2014/01/29 Javascript
JavaScript跨平台的开源框架NativeScript
2015/03/24 Javascript
js获取元素的外链样式的简单实现方法
2016/06/06 Javascript
Vue数据驱动模拟实现2
2017/01/11 Javascript
Vue表单验证插件的制作过程
2017/04/01 Javascript
详解利用jsx写vue组件的方法示例
2017/07/17 Javascript
基于jQuery实现的单行公告活动轮播效果
2017/08/23 jQuery
layer子层给父层页面元素赋值,以达到向父层页面传值的效果实例
2017/09/22 Javascript
基于openlayers4实现点的扩散效果
2020/08/17 Javascript
动态加载JavaScript文件的3种方式
2018/05/05 Javascript
vue 中swiper的使用教程
2018/05/22 Javascript
vue-cli 为项目设置别名的方法
2019/10/15 Javascript
javascript设计模式 ? 单例模式原理与应用实例分析
2020/04/09 Javascript
vue实现移动端返回顶部
2020/10/12 Javascript
以windows service方式运行Python程序的方法
2015/06/03 Python
详解python的webrtc库实现语音端点检测
2017/05/31 Python
pycharm新建一个python工程步骤
2019/07/16 Python
python pandas移动窗口函数rolling的用法
2020/02/29 Python
Python虚拟环境virtualenv创建及使用过程图解
2020/12/08 Python
德国综合购物网站:OTTO
2018/11/13 全球购物
2014年中秋寄语
2014/08/11 职场文书
考试作弊万能检讨书
2014/10/19 职场文书
党支部2014年度工作总结
2014/12/04 职场文书
2015年度电厂个人工作总结
2015/05/13 职场文书
初中语文教学反思范文
2016/03/03 职场文书
学者《孟子》名人名言
2019/08/09 职场文书