PyTorch中topk函数的用法详解


Posted in Python onJanuary 02, 2020

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

largest:如果为True,按照大到小排序; 如果为False,按照小到大排序

sorted:返回的结果按照顺序返回

out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。

假设一个PyTorch中topk函数的用法详解,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
    [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
    [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
    [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],  #【0,0】代表 第一个样本最可能属于第一类别
    [2],  # 【1, 0】代表第二个样本最可能属于第二类别
    [4],
    [2]])
# indices_max等于indices
tensor([[True],
    [True],
    [True],
    [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
    [-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
    [ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
    [ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
    [2, 1],
    [2, 1],
    [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。

其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

以上这篇PyTorch中topk函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现从百度API获取天气的方法
Mar 11 Python
利用Python绘制数据的瀑布图的教程
Apr 07 Python
python3+PyQt5实现使用剪贴板做复制与粘帖示例
Jan 24 Python
Python嵌套列表转一维的方法(压平嵌套列表)
Jul 03 Python
Python爬虫之正则表达式的使用教程详解
Oct 25 Python
选择Python写网络爬虫的优势和理由
Jul 07 Python
Django的models模型的具体使用
Jul 15 Python
django使用haystack调用Elasticsearch实现索引搜索
Jul 24 Python
python3 深浅copy对比详解
Aug 12 Python
Python Tkinter实例——模拟掷骰子
Oct 24 Python
Python实现网络聊天室的示例代码(支持多人聊天与私聊)
Jan 27 Python
windows安装python超详细图文教程
May 21 Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
You might like
解决phpmyadmin 乱码,支持gb2312和utf-8
2006/11/20 PHP
PHP 文件编程综合案例-文件上传的实现
2013/07/03 PHP
php将url地址转化为完整的a标签链接代码(php为url地址添加a标签)
2014/01/17 PHP
Android App中DrawerLayout抽屉效果的菜单编写实例
2016/03/21 PHP
深入理解PHP之OpCode原理详解
2016/06/01 PHP
thinkPHP框架实现多表查询的方法
2018/06/14 PHP
PHP 实现重载
2021/03/09 PHP
JavaScript Undefined,Null类型和NaN值区别
2008/10/22 Javascript
js 页面输出值
2008/11/30 Javascript
JavaScript 题型问答有答案参考
2010/02/17 Javascript
Prototype源码浅析 String部分(二)
2012/01/16 Javascript
JQuery实现简单的图片滑动切换特效
2015/11/22 Javascript
详解nodejs模板引擎制作
2017/06/14 NodeJs
JS失效 提示HTML1114: (UNICODE 字节顺序标记)的代码页 utf-8 覆盖(META 标记)的冲突的代码页 utf-8
2017/06/23 Javascript
vue插件vue-resource的使用笔记(小结)
2017/08/04 Javascript
深入理解Vue官方文档梳理之全局API
2017/11/22 Javascript
vue.js添加一些触摸事件以及安装fastclick的实例
2018/08/28 Javascript
jQuery使用bind动态绑定事件无效的处理方法
2018/12/11 jQuery
layui 地区三级联动 form select 渲染的实例
2019/09/27 Javascript
Vue项目vscode 安装eslint插件的方法(代码自动修复)
2020/04/15 Javascript
jQuery 选择器用法实例分析【prev + next】
2020/05/22 jQuery
python实现定制交互式命令行的方法
2014/07/03 Python
python多重继承实例
2014/10/11 Python
tensorflow 中对数组元素的操作方法
2018/07/27 Python
python3的pip路径在哪
2020/06/23 Python
在线购买世界上最好的酒:BoozeBud
2018/06/07 全球购物
英国川宁茶官方网站:Twinings茶
2019/05/21 全球购物
幼儿园实习自我鉴定
2013/12/15 职场文书
《狼》教学反思
2014/03/02 职场文书
啤酒节策划方案
2014/05/28 职场文书
环境保护标语
2014/06/20 职场文书
商场租赁意向书
2014/07/30 职场文书
安全月宣传标语
2014/10/07 职场文书
python内置进制转换函数的操作
2021/06/02 Python
python数字图像处理之图像的批量处理
2022/06/28 Python
Redis过期数据是否会被立马删除
2022/07/23 Redis