Pytorch中Softmax和LogSoftmax的使用详解


Posted in Python onJune 05, 2021

一、函数解释

1.Softmax函数常用的用法是指定参数dim就可以:

(1)dim=0:对每一列的所有元素进行softmax运算,并使得每一列所有元素和为1。

(2)dim=1:对每一行的所有元素进行softmax运算,并使得每一行所有元素和为1。

class Softmax(Module):
    r"""Applies the Softmax function to an n-dimensional input Tensor
    rescaling them so that the elements of the n-dimensional output Tensor
    lie in the range [0,1] and sum to 1.
    Softmax is defined as:
    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input
    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]
    Arguments:
        dim (int): A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).
    .. note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).
    Examples::
        >>> m = nn.Softmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']
 
    def __init__(self, dim=None):
        super(Softmax, self).__init__()
        self.dim = dim
 
    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None
 
    def forward(self, input):
        return F.softmax(input, self.dim, _stacklevel=5)
 
    def extra_repr(self):
        return 'dim={dim}'.format(dim=self.dim)

2.LogSoftmax其实就是对softmax的结果进行log,即Log(Softmax(x))

class LogSoftmax(Module):
    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
    input Tensor. The LogSoftmax formulation can be simplified as:
    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input
    Arguments:
        dim (int): A dimension along which LogSoftmax will be computed.
    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [-inf, 0)
    Examples::
        >>> m = nn.LogSoftmax()
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']
 
    def __init__(self, dim=None):
        super(LogSoftmax, self).__init__()
        self.dim = dim
 
    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None
 
    def forward(self, input):
        return F.log_softmax(input, self.dim, _stacklevel=5)

二、代码示例

输入代码

import torch
import torch.nn as nn
import numpy as np
 
batch_size = 4
class_num = 6
inputs = torch.randn(batch_size, class_num)
for i in range(batch_size):
    for j in range(class_num):
        inputs[i][j] = (i + 1) * (j + 1)
 
print("inputs:", inputs)

得到大小batch_size为4,类别数为6的向量(可以理解为经过最后一层得到)

tensor([[ 1., 2., 3., 4., 5., 6.],
[ 2., 4., 6., 8., 10., 12.],
[ 3., 6., 9., 12., 15., 18.],
[ 4., 8., 12., 16., 20., 24.]])

接着我们对该向量每一行进行Softmax

Softmax = nn.Softmax(dim=1)
probs = Softmax(inputs)
print("probs:\n", probs)

得到

tensor([[4.2698e-03, 1.1606e-02, 3.1550e-02, 8.5761e-02, 2.3312e-01, 6.3369e-01],
[3.9256e-05, 2.9006e-04, 2.1433e-03, 1.5837e-02, 1.1702e-01, 8.6467e-01],
[2.9067e-07, 5.8383e-06, 1.1727e-04, 2.3553e-03, 4.7308e-02, 9.5021e-01],
[2.0234e-09, 1.1047e-07, 6.0317e-06, 3.2932e-04, 1.7980e-02, 9.8168e-01]])

此外,我们对该向量每一行进行LogSoftmax

LogSoftmax = nn.LogSoftmax(dim=1)
log_probs = LogSoftmax(inputs)
print("log_probs:\n", log_probs)

得到

tensor([[-5.4562e+00, -4.4562e+00, -3.4562e+00, -2.4562e+00, -1.4562e+00, -4.5619e-01],
[-1.0145e+01, -8.1454e+00, -6.1454e+00, -4.1454e+00, -2.1454e+00, -1.4541e-01],
[-1.5051e+01, -1.2051e+01, -9.0511e+00, -6.0511e+00, -3.0511e+00, -5.1069e-02],
[-2.0018e+01, -1.6018e+01, -1.2018e+01, -8.0185e+00, -4.0185e+00, -1.8485e-02]])

验证每一行元素和是否为1

# probs_sum in dim=1
probs_sum = [0 for i in range(batch_size)]
 
for i in range(batch_size):
    for j in range(class_num):
        probs_sum[i] += probs[i][j]
    print(i, "row probs sum:", probs_sum[i])

得到每一行的和,看到确实为1

0 row probs sum: tensor(1.)
1 row probs sum: tensor(1.0000)
2 row probs sum: tensor(1.)
3 row probs sum: tensor(1.)

验证LogSoftmax是对Softmax的结果进行Log

# to numpy
np_probs = probs.data.numpy()
print("numpy probs:\n", np_probs)
 
# np.log()
log_np_probs = np.log(np_probs)
print("log numpy probs:\n", log_np_probs)

得到

numpy probs:
[[4.26977826e-03 1.16064614e-02 3.15496325e-02 8.57607946e-02 2.33122006e-01 6.33691311e-01]
[3.92559559e-05 2.90064461e-04 2.14330270e-03 1.58369839e-02 1.17020354e-01 8.64669979e-01]
[2.90672347e-07 5.83831024e-06 1.17265590e-04 2.35534250e-03 4.73083146e-02 9.50212955e-01]
[2.02340233e-09 1.10474026e-07 6.03167746e-06 3.29318427e-04 1.79801770e-02 9.81684387e-01]]
log numpy probs:
[[-5.4561934e+00 -4.4561934e+00 -3.4561934e+00 -2.4561932e+00 -1.4561933e+00 -4.5619333e-01]
[-1.0145408e+01 -8.1454077e+00 -6.1454072e+00 -4.1454072e+00 -2.1454074e+00 -1.4540738e-01]
[-1.5051069e+01 -1.2051069e+01 -9.0510693e+00 -6.0510693e+00 -3.0510693e+00 -5.1069155e-02]
[-2.0018486e+01 -1.6018486e+01 -1.2018485e+01 -8.0184851e+00 -4.0184855e+00 -1.8485421e-02]]

验证完毕

三、整体代码

import torch
import torch.nn as nn
import numpy as np
 
batch_size = 4
class_num = 6
inputs = torch.randn(batch_size, class_num)
for i in range(batch_size):
    for j in range(class_num):
        inputs[i][j] = (i + 1) * (j + 1)
 
print("inputs:", inputs)
Softmax = nn.Softmax(dim=1)
probs = Softmax(inputs)
print("probs:\n", probs)
 
LogSoftmax = nn.LogSoftmax(dim=1)
log_probs = LogSoftmax(inputs)
print("log_probs:\n", log_probs)
 
# probs_sum in dim=1
probs_sum = [0 for i in range(batch_size)]
 
for i in range(batch_size):
    for j in range(class_num):
        probs_sum[i] += probs[i][j]
    print(i, "row probs sum:", probs_sum[i])
 
# to numpy
np_probs = probs.data.numpy()
print("numpy probs:\n", np_probs)
 
# np.log()
log_np_probs = np.log(np_probs)
print("log numpy probs:\n", log_np_probs)

基于pytorch softmax,logsoftmax 表达

import torch
import numpy as np
input = torch.autograd.Variable(torch.rand(1, 3))

print(input)
print('softmax={}'.format(torch.nn.functional.softmax(input, dim=1)))
print('logsoftmax={}'.format(np.log(torch.nn.functional.softmax(input, dim=1))))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python模块restful使用方法实例
Dec 10 Python
详解Python编程中包的概念与管理
Oct 16 Python
Python 中的with关键字使用详解
Sep 11 Python
python fabric实现远程部署
Jan 05 Python
python实现多层感知器MLP(基于双月数据集)
Jan 18 Python
如何使用Python进行OCR识别图片中的文字
Apr 01 Python
python实现图片压缩代码实例
Aug 12 Python
如何基于python操作json文件获取内容
Dec 24 Python
在Sublime Editor中配置Python环境的详细教程
May 03 Python
Python中openpyxl实现vlookup函数的实例
Oct 28 Python
python程序实现BTC(比特币)挖矿的完整代码
Jan 20 Python
详解Python牛顿插值法
May 11 Python
Pytorch中Softmax与LogSigmoid的对比分析
Jun 05 #Python
Pytorch反向传播中的细节-计算梯度时的默认累加操作
pytorch 梯度NAN异常值的解决方案
Jun 05 #Python
pytorch 权重weight 与 梯度grad 可视化操作
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
You might like
PHP __autoload函数(自动载入类文件)的使用方法
2012/02/04 PHP
php fread函数使用方法总结
2019/05/28 PHP
JavaScript 权威指南(第四版) 读书笔记
2009/08/11 Javascript
JS window.opener返回父页面的应用
2009/10/24 Javascript
jquery 查找iframe父级页面元素的实现代码
2011/08/28 Javascript
JavaScript中“过于”犀利地for/in循环使用示例
2013/10/22 Javascript
如何在JavaScript中实现私有属性的写类方式(一)
2013/12/04 Javascript
js的hasownproperty使用示例
2014/03/02 Javascript
JavaScript 作用域链解析
2014/11/13 Javascript
仿iframe效果Aajx文件上传实例
2016/11/18 Javascript
js实现PC端根据IP定位当前城市地理位置
2017/02/22 Javascript
vue.js移动数组位置,同时更新视图的方法
2018/03/08 Javascript
iconfont的三种使用方式详解
2018/08/05 Javascript
vue slot与传参实例代码讲解
2019/04/28 Javascript
Vue对象赋值视图不更新问题及解决方法
2019/06/03 Javascript
JS实现秒杀倒计时特效
2020/01/02 Javascript
Express 配置HTML页面访问的实现
2020/11/01 Javascript
JavaScript实现网页动态生成表格
2020/11/25 Javascript
python类参数self使用示例
2014/02/17 Python
python实现从字符串中找出字符1的位置以及个数的方法
2014/08/25 Python
详解Python中内置的NotImplemented类型的用法
2015/03/31 Python
python函数装饰器用法实例详解
2015/06/04 Python
python扫描proxy并获取可用代理ip的实例
2017/08/07 Python
python将字典内容存入mysql实例代码
2018/01/18 Python
python负载均衡的简单实现方法
2018/02/04 Python
Python3实现腾讯云OCR识别
2018/11/27 Python
python版本五子棋的实现代码
2018/12/11 Python
Python基于机器学习方法实现的电影推荐系统实例详解
2019/06/25 Python
解决Django加载静态资源失败的问题
2019/07/28 Python
Django DRF认证组件流程实现原理详解
2020/08/17 Python
python爬虫调度器用法及实例代码
2020/11/30 Python
Python 的 f-string 可以连接字符串与数字的原因解析
2021/02/20 Python
英国儿童家具专卖店:GLTC
2016/09/24 全球购物
猫途鹰:全球领先的旅游点评社区
2017/04/07 全球购物
党员贯彻十八大精神思想汇报范文
2014/10/25 职场文书
MySQL优化常用的19种有效方法(推荐!)
2022/03/17 MySQL