PyTorch的SoftMax交叉熵损失和梯度用法


Posted in Python onJanuary 15, 2020

在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算

关于softmax_cross_entropy求导的过程,可以参考HERE

示例

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# 对data求梯度, 用于反向传播
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

# 多分类标签 one-hot格式
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

输出:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:损失 和 input's grad:输入的梯度
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

注意

对于单输入的loss 和 grad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#分别令不同索引位置label为1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)

其输出:

# 第一组:
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

# 第二组:
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

# 第三组:
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
解释:
对于输入数据 tensor([[ 1., 2., 3.]]) softmax之后的结果如下
tensor([[ 0.0900, 0.2447, 0.6652]])
交叉熵求解梯度推导公式可知 s[0, 0]-1, s[0, 1]-1, s[0, 2]-1 是上面三组label对应的输入数据梯度
"""

pytorch提供的softmax, 和log_softmax 关系

# 官方提供的softmax实现
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: # 等同于 pytorch 提供的 softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax # 计算结果相同
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

# 与log_softmax关系
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])

官方提供的softmax交叉熵求解结果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

# 交叉熵的计算方式之一
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss输入是log_softmax, target是非one-hot格式的label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)

"""
# 交叉熵计算方式二
loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted
#CrossEntropyLoss适用于分类问题的损失函数
#input:没有softmax过的nn.output, target是非one-hot格式label
loss = loss_fn(data, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "\ngrad: ", data.grad)
"""
"""

输出

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])

通过和示例的输出对比, 发现两者是一样的

以上这篇PyTorch的SoftMax交叉熵损失和梯度用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
查看Python安装路径以及安装包路径小技巧
Apr 28 Python
python实现折半查找和归并排序算法
Apr 14 Python
Python中生成Epoch的方法
Apr 26 Python
Python有序查找算法之二分法实例分析
Dec 11 Python
python调用百度语音REST API
Aug 30 Python
Python解析命令行读取参数之argparse模块
Jul 26 Python
python匿名函数用法实例分析
Aug 03 Python
Django模板语言 Tags使用详解
Sep 09 Python
Python 读取xml数据,cv2裁剪图片实例
Mar 10 Python
Django项目uwsgi+Nginx保姆级部署教程实现
Apr 19 Python
python爬虫实例之获取动漫截图
May 31 Python
python lambda的使用详解
Feb 26 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 #Python
Python 中@property的用法详解
Jan 15 #Python
Python字符串中删除特定字符的方法
Jan 15 #Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 #Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 #Python
You might like
使用PHP实现密保卡功能实现代码<打包下载直接运行>
2011/10/09 PHP
nginx+thinkphp下解决不支持pathinfo模式
2015/07/01 PHP
php判断目录存在的简单方法
2019/09/26 PHP
laravel-admin 后台表格筛选设置默认的查询日期方法
2019/10/03 PHP
JS中不为人知的五种声明Number的方式简要概述
2013/02/22 Javascript
javascript弹出页面回传值的方法
2015/01/28 Javascript
浅析js中substring和substr的方法
2015/11/09 Javascript
javascript每日必学之基础入门
2016/02/16 Javascript
JS提示:Uncaught SyntaxError:Unexpected token ) 错误的解决方法
2016/08/19 Javascript
JavaScript基础——使用Canvas绘图
2016/11/02 Javascript
javascript解析ajax返回的xml和json格式数据实例详解
2017/01/05 Javascript
angular 动态组件类型详解(四种组件类型)
2017/02/22 Javascript
jquery实现左右滑动式轮播图
2017/03/02 Javascript
nodejs对express中next函数的一些理解
2017/09/08 NodeJs
JS实现区分中英文并统计字符个数的方法示例
2018/06/09 Javascript
JS获取浏览器地址栏的多个参数值的任意值实例代码
2018/07/24 Javascript
在axios中使用params传参的时候传入数组的方法
2018/09/25 Javascript
vue双向绑定及观察者模式详解
2019/03/19 Javascript
vue history 模式打包部署在域名的二级目录的配置指南
2019/07/02 Javascript
python生成随机密码或随机字符串的方法
2015/07/03 Python
pandas全表查询定位某个值所在行列的方法
2018/04/12 Python
Python处理菜单消息操作示例【基于win32ui模块】
2018/05/09 Python
python写入文件自动换行问题的方法
2019/07/05 Python
Python读取pdf表格写入excel的方法
2021/01/22 Python
html5 canvas-2.用canvas制作一个猜字母的小游戏
2013/01/07 HTML / CSS
西班牙英格列斯百货官网:El Corte Inglés
2016/09/25 全球购物
联想加拿大官方网站:Lenovo Canada
2018/04/05 全球购物
Wilson体育用品官网:美国著名运动器材品牌
2019/05/12 全球购物
巴西Bo.Bô官方在线商店:经营奢侈品时尚业务
2020/03/16 全球购物
面向对象设计的原则是什么
2013/02/13 面试题
Java如何获得ResultSet的总行数
2016/09/03 面试题
转让协议书范本
2014/04/15 职场文书
五五普法心得体会
2014/09/04 职场文书
婚礼答谢礼品
2015/01/20 职场文书
摩登时代观后感
2015/06/03 职场文书
MySQL使用IF语句及用case语句对条件并结果进行判断 
2022/09/23 MySQL