Pytorch损失函数nn.NLLLoss2d()用法说明


Posted in Python onJuly 07, 2020

最近做显著星检测用到了NLL损失函数

对于NLL函数,需要自己计算log和softmax的概率值,然后从才能作为输入

输入 [batch_size, channel , h, w]

Pytorch损失函数nn.NLLLoss2d()用法说明

目标 [batch_size, h, w]

输入的目标矩阵,每个像素必须是类型.举个例子。第一个像素是0,代表着类别属于输入的第1个通道;第二个像素是0,代表着类别属于输入的第0个通道,以此类推。

x = Variable(torch.Tensor([[[1, 2, 1],
       [2, 2, 1],
       [0, 1, 1]],
       [[0, 1, 3],
       [2, 3, 1],
       [0, 0, 1]]]))

x = x.view([1, 2, 3, 3])
print("x输入", x)

这里输入x,并改成[batch_size, channel , h, w]的格式。

soft = nn.Softmax(dim=1)

log_soft = nn.LogSoftmax(dim=1)

然后使用softmax函数计算每个类别的概率,这里dim=1表示从在1维度

上计算,也就是channel维度。logsoftmax是计算完softmax后在计算log值

Pytorch损失函数nn.NLLLoss2d()用法说明

手动计算举个栗子:第一个元素

Pytorch损失函数nn.NLLLoss2d()用法说明

y = Variable(torch.LongTensor([[1, 0, 1],
       [0, 0, 1],
       [1, 1, 1]]))

y = y.view([1, 3, 3])

输入label y,改变成[batch_size, h, w]格式

loss = nn.NLLLoss2d()
out = loss(x, y)
print(out)

输入函数,得到loss=0.7947

来手动计算

第一个label=1,则 loss=-1.3133

第二个label=0, 则loss=-0.3133

.
…
…
loss= -(-1.3133-0.3133-0.1269-0.6931-1.3133-0.6931-0.6931-1.3133-0.6931)/9 =0.7947222222222223

是一致的

注意:这个函数会对每个像素做平均,每个batch也会做平均,这里有9个像素,1个batch_size。

补充知识:PyTorch:NLLLoss2d

我就废话不多说了,大家还是直接看代码吧~

import torch
import torch.nn as nn
from torch import autograd
import torch.nn.functional as F
 
inputs_tensor = torch.FloatTensor([
[[2, 4],
 [1, 2]],
[[5, 3],
 [3, 0]],
[[5, 3],
 [5, 2]],
[[4, 2],
 [3, 2]],
 ])
inputs_tensor = torch.unsqueeze(inputs_tensor,0)
# inputs_tensor = torch.unsqueeze(inputs_tensor,1)
print '--input size(nBatch x nClasses x height x width): ', inputs_tensor.shape
 
targets_tensor = torch.LongTensor([
 [0, 2],
 [2, 3]
])
 
targets_tensor = torch.unsqueeze(targets_tensor,0)
print '--target size(nBatch x height x width): ', targets_tensor.shape
 
inputs_variable = autograd.Variable(inputs_tensor, requires_grad=True)
inputs_variable = F.log_softmax(inputs_variable)
targets_variable = autograd.Variable(targets_tensor)
 
loss = nn.NLLLoss2d()
output = loss(inputs_variable, targets_variable)
print '--NLLLoss2d: {}'.format(output)

以上这篇Pytorch损失函数nn.NLLLoss2d()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 测试实现方法
Dec 24 Python
Python中统计函数运行耗时的方法
May 05 Python
在Python的Django框架中创建语言文件
Jul 27 Python
基于python3 类的属性、方法、封装、继承实例讲解
Sep 19 Python
Django学习笔记之ORM基础教程
Mar 27 Python
Python实现的简单计算器功能详解
Aug 25 Python
python微信好友数据分析详解
Nov 19 Python
pandas中read_csv的缺失值处理方式
Dec 19 Python
python opencv 实现对图像边缘扩充
Jan 19 Python
tensorflow 利用expand_dims和squeeze扩展和压缩tensor维度方式
Feb 07 Python
Python中qutip用法示例详解
Oct 02 Python
Python lxml库的简单介绍及基本使用讲解
Dec 22 Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
pytorch随机采样操作SubsetRandomSampler()
Jul 07 #Python
pytorch加载自己的图像数据集实例
Jul 07 #Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 #Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 #Python
Python调用C语言程序方法解析
Jul 07 #Python
You might like
syphon 虹吸式咖啡冲泡冲煮倒水的得与失
2021/03/03 冲泡冲煮
压力如何影响浓缩咖啡品质
2021/03/03 咖啡文化
PHP函数篇之掌握ord()与chr()函数应用
2011/12/05 PHP
PHP JS Ip地址及域名格式检测代码
2013/09/27 PHP
高性能PHP框架Symfony2经典入门教程
2014/07/08 PHP
解决Extjs上传图片无法预览的解决方法
2012/03/22 Javascript
Ext GridPanel加载完数据后进行操作示例代码
2014/06/17 Javascript
JavaScript计算两个日期时间段内日期的方法
2015/03/16 Javascript
javascript运动效果实例总结(放大缩小、滑动淡入、滚动)
2016/01/08 Javascript
使用getBoundingClientRect方法实现简洁的sticky组件的方法
2016/03/22 Javascript
JS中取二维数组中最大值的方法汇总
2016/04/17 Javascript
详解JS对象封装的常用方式
2016/12/30 Javascript
nodejs的压缩文件模块archiver用法示例
2017/01/18 NodeJs
jQuery扇形定时器插件pietimer使用方法详解
2017/07/18 jQuery
JS实现二维数组横纵列转置的方法
2018/04/17 Javascript
vue拦截器实现统一token,并兼容IE9验证功能
2018/04/26 Javascript
微信小程序BindTap快速连续点击目标页面跳转多次问题处理
2019/04/08 Javascript
详解axios中封装使用、拦截特定请求、判断所有请求加载完毕)
2019/04/09 Javascript
详解Webpack如何引入CDN链接来优化编译后的体积
2019/06/21 Javascript
Python 模拟购物车的实例讲解
2017/09/11 Python
Python批量生成幻影坦克图片实例代码
2019/06/04 Python
使用WingPro 7 设置Python路径的方法
2019/07/24 Python
浅谈keras使用中val_acc和acc值不同步的思考
2020/06/18 Python
matplotlib基础绘图命令之imshow的使用
2020/08/13 Python
关于python中导入文件到list的问题
2020/10/31 Python
python3中编码获取网页的实例方法
2020/11/16 Python
使用索引有什么好处
2016/07/27 面试题
毕业生实习鉴定
2013/12/11 职场文书
运动会通讯稿100字
2014/01/31 职场文书
授权委托书怎么写
2014/09/25 职场文书
漂亮妈妈观后感
2015/06/08 职场文书
请病假条范文
2015/08/17 职场文书
大学生各类奖学金申请书
2019/06/24 职场文书
nginx反向代理配置去除前缀案例教程
2021/07/26 Servers
PYTHON 使用 Pandas 删除某列指定值所在的行
2022/04/28 Python
Apache SeaTunnel实现 非CDC数据抽取
2022/05/20 Servers