Pytorch损失函数nn.NLLLoss2d()用法说明


Posted in Python onJuly 07, 2020

最近做显著星检测用到了NLL损失函数

对于NLL函数,需要自己计算log和softmax的概率值,然后从才能作为输入

输入 [batch_size, channel , h, w]

Pytorch损失函数nn.NLLLoss2d()用法说明

目标 [batch_size, h, w]

输入的目标矩阵,每个像素必须是类型.举个例子。第一个像素是0,代表着类别属于输入的第1个通道;第二个像素是0,代表着类别属于输入的第0个通道,以此类推。

x = Variable(torch.Tensor([[[1, 2, 1],
       [2, 2, 1],
       [0, 1, 1]],
       [[0, 1, 3],
       [2, 3, 1],
       [0, 0, 1]]]))

x = x.view([1, 2, 3, 3])
print("x输入", x)

这里输入x,并改成[batch_size, channel , h, w]的格式。

soft = nn.Softmax(dim=1)

log_soft = nn.LogSoftmax(dim=1)

然后使用softmax函数计算每个类别的概率,这里dim=1表示从在1维度

上计算,也就是channel维度。logsoftmax是计算完softmax后在计算log值

Pytorch损失函数nn.NLLLoss2d()用法说明

手动计算举个栗子:第一个元素

Pytorch损失函数nn.NLLLoss2d()用法说明

y = Variable(torch.LongTensor([[1, 0, 1],
       [0, 0, 1],
       [1, 1, 1]]))

y = y.view([1, 3, 3])

输入label y,改变成[batch_size, h, w]格式

loss = nn.NLLLoss2d()
out = loss(x, y)
print(out)

输入函数,得到loss=0.7947

来手动计算

第一个label=1,则 loss=-1.3133

第二个label=0, 则loss=-0.3133

.
…
…
loss= -(-1.3133-0.3133-0.1269-0.6931-1.3133-0.6931-0.6931-1.3133-0.6931)/9 =0.7947222222222223

是一致的

注意:这个函数会对每个像素做平均,每个batch也会做平均,这里有9个像素,1个batch_size。

补充知识:PyTorch:NLLLoss2d

我就废话不多说了,大家还是直接看代码吧~

import torch
import torch.nn as nn
from torch import autograd
import torch.nn.functional as F
 
inputs_tensor = torch.FloatTensor([
[[2, 4],
 [1, 2]],
[[5, 3],
 [3, 0]],
[[5, 3],
 [5, 2]],
[[4, 2],
 [3, 2]],
 ])
inputs_tensor = torch.unsqueeze(inputs_tensor,0)
# inputs_tensor = torch.unsqueeze(inputs_tensor,1)
print '--input size(nBatch x nClasses x height x width): ', inputs_tensor.shape
 
targets_tensor = torch.LongTensor([
 [0, 2],
 [2, 3]
])
 
targets_tensor = torch.unsqueeze(targets_tensor,0)
print '--target size(nBatch x height x width): ', targets_tensor.shape
 
inputs_variable = autograd.Variable(inputs_tensor, requires_grad=True)
inputs_variable = F.log_softmax(inputs_variable)
targets_variable = autograd.Variable(targets_tensor)
 
loss = nn.NLLLoss2d()
output = loss(inputs_variable, targets_variable)
print '--NLLLoss2d: {}'.format(output)

以上这篇Pytorch损失函数nn.NLLLoss2d()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python构建Hopfield网络的教程
Apr 14 Python
用Python的Django框架来制作一个RSS阅读器
Jul 22 Python
python3实现暴力穷举博客园密码
Jun 19 Python
Python基于正则表达式实现文件内容替换的方法
Aug 30 Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 Python
python selenium 获取标签的属性值、内容、状态方法
Jun 22 Python
Python中logging实例讲解
Jan 17 Python
python redis 删除key脚本的实例
Feb 19 Python
win7下 python3.6 安装opencv 和 opencv-contrib-python解决 cv2.xfeatures2d.SIFT_create() 的问题
Oct 24 Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 Python
python GUI库图形界面开发之PyQt5单行文本框控件QLineEdit详细使用方法与实例
Feb 27 Python
推荐技术人员一款Python开源库(造数据神器)
Jul 08 Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
pytorch随机采样操作SubsetRandomSampler()
Jul 07 #Python
pytorch加载自己的图像数据集实例
Jul 07 #Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 #Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 #Python
Python调用C语言程序方法解析
Jul 07 #Python
You might like
PHP获取搜索引擎关键字来源的函数(支持百度和谷歌等搜索引擎)
2012/10/03 PHP
php抽象方法和抽象类实例分析
2016/12/07 PHP
javascript 冒号 使用说明
2009/06/06 Javascript
用Javascript实现锚点(Anchor)间平滑跳转
2009/09/08 Javascript
JS控制文本框textarea输入字数限制的方法
2013/06/17 Javascript
JS实现点击图片在当前页面放大并可关闭的漂亮效果
2013/10/18 Javascript
jQuery实现冻结表格行和列
2015/04/29 Javascript
详细介绍jQuery.outerWidth() 函数具体用法
2015/07/20 Javascript
javascript同步服务器时间和同步倒计时小技巧
2015/09/24 Javascript
各式各样的导航条效果css3结合jquery代码实现
2016/09/17 Javascript
微信小程序 地图(map)实例详解
2016/11/16 Javascript
JavaScript中this的用法实例分析
2016/12/19 Javascript
详解Vuejs2.0 如何利用proxyTable实现跨域请求
2017/08/03 Javascript
javaScript手机号码校验工具类PhoneUtils详解
2017/12/08 Javascript
JavaScript栈和队列相关操作与实现方法详解
2018/12/07 Javascript
抖音上用记事本编写爱心小程序教程
2019/04/17 Javascript
微信小程序云开发(数据库)详解
2019/05/17 Javascript
windows下create-react-app 升级至3.3.1版本踩坑记
2020/02/17 Javascript
webstorm建立vue-cli脚手架的傻瓜式教程
2020/09/22 Javascript
Python机器学习库scikit-learn安装与基本使用教程
2018/06/25 Python
python看某个模块的版本方法
2018/10/16 Python
python调用java的jar包方法
2018/12/15 Python
Python考拉兹猜想输出序列代码实践
2019/07/05 Python
通过PHP与Python代码对比的语法差异详解
2019/07/10 Python
matplotlib命令与格式之tick坐标轴日期格式(设置日期主副刻度)
2019/08/06 Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
2020/01/14 Python
python中68个内置函数的总结与介绍
2020/02/24 Python
使用matlab 判断两个矩阵是否相等的实例
2020/05/11 Python
CSS3 特效范例整理
2011/08/22 HTML / CSS
最新大学生创业计划书写作攻略
2014/04/02 职场文书
钱学森观后感
2015/06/04 职场文书
运动会开幕式致辞
2015/07/29 职场文书
遗嘱范文
2015/08/07 职场文书
2019军训心得体会
2019/06/27 职场文书
《月歌。》宣布制作10周年纪念剧场版《RABBITS KINGDOM THE MOVIE》
2022/04/02 日漫
SQL中的连接查询详解
2022/06/21 SQL Server