Pytorch损失函数nn.NLLLoss2d()用法说明


Posted in Python onJuly 07, 2020

最近做显著星检测用到了NLL损失函数

对于NLL函数,需要自己计算log和softmax的概率值,然后从才能作为输入

输入 [batch_size, channel , h, w]

Pytorch损失函数nn.NLLLoss2d()用法说明

目标 [batch_size, h, w]

输入的目标矩阵,每个像素必须是类型.举个例子。第一个像素是0,代表着类别属于输入的第1个通道;第二个像素是0,代表着类别属于输入的第0个通道,以此类推。

x = Variable(torch.Tensor([[[1, 2, 1],
       [2, 2, 1],
       [0, 1, 1]],
       [[0, 1, 3],
       [2, 3, 1],
       [0, 0, 1]]]))

x = x.view([1, 2, 3, 3])
print("x输入", x)

这里输入x,并改成[batch_size, channel , h, w]的格式。

soft = nn.Softmax(dim=1)

log_soft = nn.LogSoftmax(dim=1)

然后使用softmax函数计算每个类别的概率,这里dim=1表示从在1维度

上计算,也就是channel维度。logsoftmax是计算完softmax后在计算log值

Pytorch损失函数nn.NLLLoss2d()用法说明

手动计算举个栗子:第一个元素

Pytorch损失函数nn.NLLLoss2d()用法说明

y = Variable(torch.LongTensor([[1, 0, 1],
       [0, 0, 1],
       [1, 1, 1]]))

y = y.view([1, 3, 3])

输入label y,改变成[batch_size, h, w]格式

loss = nn.NLLLoss2d()
out = loss(x, y)
print(out)

输入函数,得到loss=0.7947

来手动计算

第一个label=1,则 loss=-1.3133

第二个label=0, 则loss=-0.3133

.
…
…
loss= -(-1.3133-0.3133-0.1269-0.6931-1.3133-0.6931-0.6931-1.3133-0.6931)/9 =0.7947222222222223

是一致的

注意:这个函数会对每个像素做平均,每个batch也会做平均,这里有9个像素,1个batch_size。

补充知识:PyTorch:NLLLoss2d

我就废话不多说了,大家还是直接看代码吧~

import torch
import torch.nn as nn
from torch import autograd
import torch.nn.functional as F
 
inputs_tensor = torch.FloatTensor([
[[2, 4],
 [1, 2]],
[[5, 3],
 [3, 0]],
[[5, 3],
 [5, 2]],
[[4, 2],
 [3, 2]],
 ])
inputs_tensor = torch.unsqueeze(inputs_tensor,0)
# inputs_tensor = torch.unsqueeze(inputs_tensor,1)
print '--input size(nBatch x nClasses x height x width): ', inputs_tensor.shape
 
targets_tensor = torch.LongTensor([
 [0, 2],
 [2, 3]
])
 
targets_tensor = torch.unsqueeze(targets_tensor,0)
print '--target size(nBatch x height x width): ', targets_tensor.shape
 
inputs_variable = autograd.Variable(inputs_tensor, requires_grad=True)
inputs_variable = F.log_softmax(inputs_variable)
targets_variable = autograd.Variable(targets_tensor)
 
loss = nn.NLLLoss2d()
output = loss(inputs_variable, targets_variable)
print '--NLLLoss2d: {}'.format(output)

以上这篇Pytorch损失函数nn.NLLLoss2d()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
探索Python3.4中新引入的asyncio模块
Apr 08 Python
Python基于PycURL实现POST的方法
Jul 25 Python
Python的Scrapy爬虫框架简单学习笔记
Jan 20 Python
Java多线程编程中ThreadLocal类的用法及深入
Jun 21 Python
基于Python 的进程管理工具supervisor使用指南
Sep 18 Python
Python程序运行原理图文解析
Feb 10 Python
基于Python pip用国内镜像下载的方法
Jun 12 Python
python二维码操作:对QRCode和MyQR入门详解
Jun 24 Python
kafka监控获取指定topic的消息总量示例
Dec 23 Python
Python中用pyinstaller打包时的图标问题及解决方法
Feb 17 Python
python关于变量名的基础知识点
Mar 03 Python
Python如何使用paramiko模块连接linux
Mar 18 Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
pytorch随机采样操作SubsetRandomSampler()
Jul 07 #Python
pytorch加载自己的图像数据集实例
Jul 07 #Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 #Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 #Python
Python调用C语言程序方法解析
Jul 07 #Python
You might like
PHP命名空间namespace及use的简单用法分析
2018/08/03 PHP
php高清晰度无损图片压缩功能的实现代码
2018/12/09 PHP
详解php中生成标准uuid(guid)的方法
2019/04/28 PHP
laravel框架模型中非静态方法也能静态调用的原理分析
2019/11/23 PHP
javascript 页面只自动刷新一次
2009/07/10 Javascript
javascript 用原型继承来实现对象系统
2010/03/22 Javascript
JQuery live函数
2010/12/24 Javascript
jQuery学习笔记之jQuery.fn.init()的参数分析
2014/06/09 Javascript
js鼠标悬浮出现遮罩层的方法
2015/01/28 Javascript
jQuery插件dataTables添加序号列的方法
2016/07/06 Javascript
angular2中使用第三方js库的实例
2018/02/26 Javascript
axios发送post请求springMVC接收不到参数的解决方法
2018/03/05 Javascript
Vue 重置组件到初始状态的方法示例
2018/10/10 Javascript
Vuex mutitons和actions初使用详解
2019/03/04 Javascript
Vue商品控件与购物车联动效果的实例代码
2019/07/21 Javascript
浅谈Layui的eleTree树式选择器使用方法
2019/09/25 Javascript
js回调函数仿360开机
2019/12/26 Javascript
vue-socket.io跨域问题有效解决方法
2020/02/11 Javascript
JS实现拖拽元素时与另一元素碰撞检测
2020/08/27 Javascript
Vue组件通信$attrs、$listeners实现原理解析
2020/09/03 Javascript
零基础写python爬虫之使用Scrapy框架编写爬虫
2014/11/07 Python
Python面向对象编程中的类和对象学习教程
2015/03/30 Python
Python实现计算最小编辑距离
2016/03/17 Python
Python开发的十个小贴士和技巧及长常犯错误
2018/09/27 Python
详解python之heapq模块及排序操作
2019/04/04 Python
python输出决策树图形的例子
2019/08/09 Python
Python学习笔记之字符串和字符串方法实例详解
2019/08/22 Python
selenium中get_cookies()和add_cookie()的用法详解
2020/01/06 Python
Python3安装模块报错Microsoft Visual C++ 14.0 is required的解决方法
2020/07/28 Python
纯CSS3大转盘抽奖示例代码(响应式、可配置)
2017/01/13 HTML / CSS
财务部岗位职责
2013/11/19 职场文书
入党申请人的自我鉴定
2013/12/01 职场文书
大学生职业生涯规划书参考模板
2014/03/05 职场文书
红与黑读书笔记
2015/06/29 职场文书
Go语言实现一个简单的并发聊天室的项目实战
2022/03/18 Golang
vue特效之翻牌动画
2022/04/20 Vue.js