Pytorch损失函数nn.NLLLoss2d()用法说明


Posted in Python onJuly 07, 2020

最近做显著星检测用到了NLL损失函数

对于NLL函数,需要自己计算log和softmax的概率值,然后从才能作为输入

输入 [batch_size, channel , h, w]

Pytorch损失函数nn.NLLLoss2d()用法说明

目标 [batch_size, h, w]

输入的目标矩阵,每个像素必须是类型.举个例子。第一个像素是0,代表着类别属于输入的第1个通道;第二个像素是0,代表着类别属于输入的第0个通道,以此类推。

x = Variable(torch.Tensor([[[1, 2, 1],
       [2, 2, 1],
       [0, 1, 1]],
       [[0, 1, 3],
       [2, 3, 1],
       [0, 0, 1]]]))

x = x.view([1, 2, 3, 3])
print("x输入", x)

这里输入x,并改成[batch_size, channel , h, w]的格式。

soft = nn.Softmax(dim=1)

log_soft = nn.LogSoftmax(dim=1)

然后使用softmax函数计算每个类别的概率,这里dim=1表示从在1维度

上计算,也就是channel维度。logsoftmax是计算完softmax后在计算log值

Pytorch损失函数nn.NLLLoss2d()用法说明

手动计算举个栗子:第一个元素

Pytorch损失函数nn.NLLLoss2d()用法说明

y = Variable(torch.LongTensor([[1, 0, 1],
       [0, 0, 1],
       [1, 1, 1]]))

y = y.view([1, 3, 3])

输入label y,改变成[batch_size, h, w]格式

loss = nn.NLLLoss2d()
out = loss(x, y)
print(out)

输入函数,得到loss=0.7947

来手动计算

第一个label=1,则 loss=-1.3133

第二个label=0, 则loss=-0.3133

.
…
…
loss= -(-1.3133-0.3133-0.1269-0.6931-1.3133-0.6931-0.6931-1.3133-0.6931)/9 =0.7947222222222223

是一致的

注意:这个函数会对每个像素做平均,每个batch也会做平均,这里有9个像素,1个batch_size。

补充知识:PyTorch:NLLLoss2d

我就废话不多说了,大家还是直接看代码吧~

import torch
import torch.nn as nn
from torch import autograd
import torch.nn.functional as F
 
inputs_tensor = torch.FloatTensor([
[[2, 4],
 [1, 2]],
[[5, 3],
 [3, 0]],
[[5, 3],
 [5, 2]],
[[4, 2],
 [3, 2]],
 ])
inputs_tensor = torch.unsqueeze(inputs_tensor,0)
# inputs_tensor = torch.unsqueeze(inputs_tensor,1)
print '--input size(nBatch x nClasses x height x width): ', inputs_tensor.shape
 
targets_tensor = torch.LongTensor([
 [0, 2],
 [2, 3]
])
 
targets_tensor = torch.unsqueeze(targets_tensor,0)
print '--target size(nBatch x height x width): ', targets_tensor.shape
 
inputs_variable = autograd.Variable(inputs_tensor, requires_grad=True)
inputs_variable = F.log_softmax(inputs_variable)
targets_variable = autograd.Variable(targets_tensor)
 
loss = nn.NLLLoss2d()
output = loss(inputs_variable, targets_variable)
print '--NLLLoss2d: {}'.format(output)

以上这篇Pytorch损失函数nn.NLLLoss2d()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Heroku云平台上部署Python的Django框架的教程
Apr 20 Python
bpython 功能强大的Python shell
Feb 16 Python
Python端口扫描简单程序
Nov 10 Python
Python中二维列表如何获取子区域元素的组成
Jan 19 Python
python如何拆分含有多种分隔符的字符串
Mar 20 Python
计算机二级python学习教程(1) 教大家如何学习python
May 16 Python
Python 动态导入对象,importlib.import_module()的使用方法
Aug 28 Python
python生成器推导式用法简单示例
Oct 08 Python
Series和DataFrame使用简单入门
Nov 13 Python
pytorch梯度剪裁方式
Feb 04 Python
Python调用接口合并Excel表代码实例
Mar 31 Python
python获得命令行输入的参数的两种方式
Nov 02 Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
pytorch随机采样操作SubsetRandomSampler()
Jul 07 #Python
pytorch加载自己的图像数据集实例
Jul 07 #Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 #Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 #Python
Python调用C语言程序方法解析
Jul 07 #Python
You might like
用PHP开发GUI
2006/10/09 PHP
简单的PHP图片上传程序
2008/03/27 PHP
详解PHP对象的串行化与反串行化
2016/01/24 PHP
Yii2创建控制器(createController)方法详解
2016/07/23 PHP
php实现留言板功能(代码详解)
2017/03/28 PHP
PHP基于MySQLI函数封装的数据库连接工具类【定义与用法】
2017/08/11 PHP
强大的jquery插件jqeuryUI做网页对话框效果!简单
2011/04/14 Javascript
js下将阿拉伯数字每三位一逗号分隔(如:15000000转化为15,000,000)
2014/06/02 Javascript
用js代码和插件实现wordpress雪花飘落效果的四种方法
2014/12/15 Javascript
浅谈Javascript中的Function与Object
2015/01/26 Javascript
jquery图片切换插件
2015/03/16 Javascript
jQuery $.each遍历对象、数组用法实例
2015/04/16 Javascript
javascript每日必学之继承
2016/02/23 Javascript
jQuery EasyUI框架中的Datagrid数据表格组件结构详解
2016/06/09 Javascript
bootstrap网格系统使用方法解析
2017/01/13 Javascript
深入理解Vuex 模块化(module)
2017/09/26 Javascript
利用百度地图API获取当前位置信息的实例
2017/11/06 Javascript
详解使用vuex进行菜单管理
2017/12/21 Javascript
JS+HTML5 Canvas实现简单的写字板功能示例
2018/08/30 Javascript
Angular 多级路由实现登录页面跳转(小白教程)
2019/11/19 Javascript
Node.js API详解之 string_decoder用法实例分析
2020/04/29 Javascript
[03:21]【TI9纪实】Old Boys
2019/08/23 DOTA
Python学习笔记之os模块使用总结
2014/11/03 Python
简单谈谈python中的Queue与多进程
2016/08/25 Python
python使用pandas实现数据分割实例代码
2018/01/25 Python
解决Python中list里的中文输出到html模板里的问题
2018/12/17 Python
python如何提取英语pdf内容并翻译
2020/03/03 Python
Python 高效编程技巧分享
2020/09/10 Python
CSS3动画之流彩文字效果+图片模糊效果+边框伸展效果实现代码合集
2017/08/18 HTML / CSS
印尼第一大家居、生活和家具电子商务:Ruparupa
2019/11/25 全球购物
马来西亚排名第一的宠物用品店:Pets Wonderland
2020/04/16 全球购物
教师年终个人自我评价
2013/10/04 职场文书
企业总经理岗位职责
2014/02/13 职场文书
理工大学毕业生自荐信范文
2014/02/22 职场文书
运动会加油稿50字
2015/07/21 职场文书
python opencv旋转图片的使用方法
2021/06/04 Python