PyTorch的SoftMax交叉熵损失和梯度用法


Posted in Python onJanuary 15, 2020

在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算

关于softmax_cross_entropy求导的过程,可以参考HERE

示例

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# 对data求梯度, 用于反向传播
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

# 多分类标签 one-hot格式
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

输出:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:损失 和 input's grad:输入的梯度
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

注意

对于单输入的loss 和 grad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#分别令不同索引位置label为1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

其输出:

# 第一组:
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

# 第二组:
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

# 第三组:
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
解释:
对于输入数据 tensor([[ 1., 2., 3.]]) softmax之后的结果如下
tensor([[ 0.0900, 0.2447, 0.6652]])
交叉熵求解梯度推导公式可知 s[0, 0]-1, s[0, 1]-1, s[0, 2]-1 是上面三组label对应的输入数据梯度
"""

pytorch提供的softmax, 和log_softmax 关系

# 官方提供的softmax实现
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: # 等同于 pytorch 提供的 softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax # 计算结果相同
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

# 与log_softmax关系
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])

官方提供的softmax交叉熵求解结果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

# 交叉熵的计算方式之一
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss输入是log_softmax, target是非one-hot格式的label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)

"""
# 交叉熵计算方式二
loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#CrossEntropyLoss适用于分类问题的损失函数
#input:没有softmax过的nn.output, target是非one-hot格式label
loss = loss_fn(data, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)
"""
"""

输出

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

通过和示例的输出对比, 发现两者是一样的

以上这篇PyTorch的SoftMax交叉熵损失和梯度用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python每隔N秒运行指定函数的方法
Mar 16 Python
python获取指定路径下所有指定后缀文件的方法
May 26 Python
python实现的DES加密算法和3DES加密算法实例
Jun 03 Python
python3下使用cv2.imwrite存储带有中文路径图片的方法
May 10 Python
Django+JS 实现点击头像即可更改头像的方法示例
Dec 26 Python
详解Python的三种可变参数
May 08 Python
使用PyQt4 设置TextEdit背景的方法
Jun 14 Python
Django框架模型简单介绍与使用分析
Jul 18 Python
python实现批量修改服务器密码的方法
Aug 13 Python
django框架单表操作之增删改实例分析
Dec 16 Python
tensorflow求导和梯度计算实例
Jan 23 Python
Python 里最强的地图绘制神器
Mar 01 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 #Python
Python 中@property的用法详解
Jan 15 #Python
Python字符串中删除特定字符的方法
Jan 15 #Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 #Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 #Python
You might like
php 强制下载文件实现代码
2013/10/28 PHP
php实现使用正则将文本中的网址转换成链接标签
2014/12/03 PHP
PHP 获取 ping 时间的实现方法
2017/09/29 PHP
Yii 框架使用数据库(databases)的方法示例
2020/05/19 PHP
php7连接MySQL实现简易查询程序的方法
2020/10/13 PHP
js 调整select 位置的函数
2008/02/21 Javascript
javascript Excel操作知识点
2009/04/24 Javascript
javascript globalStorage类代码
2009/06/04 Javascript
指定位置如果有图片显示图片,无图片显示广告的JS
2010/06/05 Javascript
extjs tabpanel限制选项卡数量实现思路及代码
2013/04/02 Javascript
结合JQ1.9通过js正则判断各种浏览器版本的方法
2013/12/30 Javascript
js实现用户离开页面前提示是否离开此页面的方法(包括浏览器按钮事件)
2015/07/18 Javascript
js实现超简单的展开、折叠目录代码
2015/08/28 Javascript
node.js抓取并分析网页内容有无特殊内容的js文件
2015/11/17 Javascript
JavaScript之cookie技术详解
2016/11/18 Javascript
bootstrap table配置参数例子
2017/01/05 Javascript
Vue2.0实现购物车功能
2017/06/05 Javascript
vue-i18n结合Element-ui的配置方法
2019/05/20 Javascript
jQuery实现小火箭返回顶部特效
2020/02/03 jQuery
Vue+Bootstrap收藏(点赞)功能逻辑与具体实现
2020/10/22 Javascript
[07:03]显微镜下的DOTA2第九期——430圣堂刺客杀戮秀
2014/06/20 DOTA
[00:16]热血竞技场
2019/03/06 DOTA
Python numpy实现数组合并实例(vstack,hstack)
2018/01/09 Python
对django中render()与render_to_response()的区别详解
2018/10/16 Python
python3.5安装python3-tk详解
2019/04/26 Python
Python使用itchat模块实现简单的微信控制电脑功能示例
2019/08/26 Python
python通过文本在一个图中画多条线的实例
2020/02/21 Python
Python如何在main中调用函数内的函数方式
2020/06/01 Python
使用css3实现超炫的loading加载动画效果
2014/05/07 HTML / CSS
如何用canvas实现在线签名的示例代码
2018/07/10 HTML / CSS
应届大专生求职信
2014/06/26 职场文书
学校师德师风整改措施
2014/10/27 职场文书
在职证明书模板
2015/06/15 职场文书
《窃读记》教学反思
2016/02/18 职场文书
高三生物教学反思
2016/02/22 职场文书
python非标准时间的转换
2021/07/25 Python