PyTorch的SoftMax交叉熵损失和梯度用法


Posted in Python onJanuary 15, 2020

在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算

关于softmax_cross_entropy求导的过程,可以参考HERE

示例

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# 对data求梯度, 用于反向传播
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

# 多分类标签 one-hot格式
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

输出:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:损失 和 input's grad:输入的梯度
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

注意

对于单输入的loss 和 grad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#分别令不同索引位置label为1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

其输出:

# 第一组:
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

# 第二组:
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

# 第三组:
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
解释:
对于输入数据 tensor([[ 1., 2., 3.]]) softmax之后的结果如下
tensor([[ 0.0900, 0.2447, 0.6652]])
交叉熵求解梯度推导公式可知 s[0, 0]-1, s[0, 1]-1, s[0, 2]-1 是上面三组label对应的输入数据梯度
"""

pytorch提供的softmax, 和log_softmax 关系

# 官方提供的softmax实现
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: # 等同于 pytorch 提供的 softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax # 计算结果相同
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

# 与log_softmax关系
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])

官方提供的softmax交叉熵求解结果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

# 交叉熵的计算方式之一
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss输入是log_softmax, target是非one-hot格式的label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)

"""
# 交叉熵计算方式二
loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#CrossEntropyLoss适用于分类问题的损失函数
#input:没有softmax过的nn.output, target是非one-hot格式label
loss = loss_fn(data, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)
"""
"""

输出

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

通过和示例的输出对比, 发现两者是一样的

以上这篇PyTorch的SoftMax交叉熵损失和梯度用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之print详解
Sep 28 Python
python实现每次处理一个字符的三种方法
Oct 09 Python
介绍Python中的__future__模块
Apr 27 Python
Python 通过URL打开图片实例详解
Jun 01 Python
对python 匹配字符串开头和结尾的方法详解
Oct 27 Python
对Python3 * 和 ** 运算符详解
Feb 16 Python
Python3的高阶函数map,reduce,filter的示例详解
Jul 23 Python
Flask框架学习笔记之使用Flask实现表单开发详解
Aug 12 Python
python子线程退出及线程退出控制的代码
Oct 16 Python
idea2020手动安装python插件的实现方法
Jul 17 Python
python 实现aes256加密
Nov 27 Python
Python字符串对齐、删除字符串不需要的内容以及格式化打印字符
Jan 23 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 #Python
Python 中@property的用法详解
Jan 15 #Python
Python字符串中删除特定字符的方法
Jan 15 #Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 #Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 #Python
You might like
php面向对象全攻略 (九)访问类型
2009/09/30 PHP
PHP独立Session数据库存储操作类分享
2014/06/11 PHP
PHP抓取及分析网页的方法详解
2016/04/26 PHP
PHP自动补全表单的两种方法
2017/03/06 PHP
PHP检测一个数组有没有定义的方法步骤
2019/07/20 PHP
JS判断移动端访问设备并加载对应CSS样式
2014/06/13 Javascript
简介JavaScript中valueOf()方法的使用
2015/06/05 Javascript
JavaScript中反正弦函数Math.asin()的使用简介
2015/06/14 Javascript
Bootstrap实现input控件失去焦点时验证
2016/08/04 Javascript
Angularjs的Controller间通信机制实例分析
2016/11/07 Javascript
利用Query+bootstrap和js两种方式实现日期选择器
2017/01/10 Javascript
react.js 翻页插件实例代码
2017/01/19 Javascript
node.js 发布订阅模式的实例
2017/09/10 Javascript
浅谈webpack下的AOP式无侵入注入
2017/11/12 Javascript
vue-content-loader内容加载器的使用方法
2018/08/05 Javascript
微信小程序搭建(mpvue+mpvue-weui+fly.js)的详细步骤
2018/09/18 Javascript
echarts饼图各个板块之间的空隙如何实现
2020/12/01 Javascript
[33:09]完美世界DOTA2联赛循环赛 Forest vs DM BO2第二场 10.29
2020/10/29 DOTA
Python提取Linux内核源代码的目录结构实现方法
2016/06/24 Python
python实现kNN算法
2017/12/20 Python
Python+matplotlib实现华丽的文本框演示代码
2018/01/22 Python
对python读写文件去重、RE、set的使用详解
2018/12/11 Python
Python paramiko模块使用解析(实现ssh)
2019/08/30 Python
keras的siamese(孪生网络)实现案例
2020/06/12 Python
如何用Django处理gzip数据流
2021/01/29 Python
浅析几个CSS3常用功能的写法
2014/06/05 HTML / CSS
Html5插件教程之添加浏览器放大镜效果的商品橱窗
2016/01/07 HTML / CSS
瑰珀翠美国官网:Crabtree & Evelyn美国
2016/11/29 全球购物
婚庆公司的创业计划书
2014/01/22 职场文书
小学毕业感言500字
2014/02/28 职场文书
技术经济专业求职信
2014/09/03 职场文书
三年级学生期末评语
2014/12/26 职场文书
会议欢迎词
2015/01/23 职场文书
员工保密协议范本,您一定得收藏!很有用!
2019/08/08 职场文书
多属性、多分类MySQL模式设计
2021/04/05 MySQL
SONY600GR,国产收音机厂商永远的痛
2022/04/05 无线电