PyTorch中topk函数的用法详解


Posted in Python onJanuary 02, 2020

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

largest:如果为True,按照大到小排序; 如果为False,按照小到大排序

sorted:返回的结果按照顺序返回

out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。

假设一个PyTorch中topk函数的用法详解,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
    [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
    [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
    [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],  #【0,0】代表 第一个样本最可能属于第一类别
    [2],  # 【1, 0】代表第二个样本最可能属于第二类别
    [4],
    [2]])
# indices_max等于indices
tensor([[True],
    [True],
    [True],
    [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
    [-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
    [ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
    [ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
    [2, 1],
    [2, 1],
    [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。

其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

以上这篇PyTorch中topk函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现2014火车票查询代码分享
Jan 10 Python
在python的WEB框架Flask中使用多个配置文件的解决方法
Apr 18 Python
Python写入数据到MP3文件中的方法
Jul 10 Python
详解Python迭代和迭代器
Mar 28 Python
简单讲解Python编程中namedtuple类的用法
Jun 21 Python
用Python实现KNN分类算法
Dec 22 Python
PyCharm 配置远程python解释器和在本地修改服务器代码
Jul 23 Python
python 爬虫百度地图的信息界面的实现方法
Oct 27 Python
flask框架自定义url转换器操作详解
Jan 25 Python
python中count函数简单的实例讲解
Feb 06 Python
jupyter notebook清除输出方式
Apr 10 Python
pandas 按日期范围筛选数据的实现
Feb 20 Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
You might like
php中的三元运算符使用说明
2011/07/03 PHP
第二章 PHP入门基础之php代码写法
2011/12/30 PHP
PHP基于imagick扩展实现合成图片的两种方法【附imagick扩展下载】
2017/11/14 PHP
PHP7生产环境队列Beanstalkd用法详解
2020/05/19 PHP
基于jquery的jqDnR拖拽溢出的修改
2011/02/12 Javascript
jQuery弹出层始终垂直居中相对于屏幕或当前窗口
2013/04/01 Javascript
jQuery实现强制cookie过期方法汇总
2015/05/22 Javascript
jQuery双向列表选择器DIV模拟版
2016/11/01 Javascript
详解JavaScript常量定义
2017/01/03 Javascript
Element中的Cascader(级联列表)动态加载省\市\区数据的方法
2019/03/27 Javascript
详解从0开始搭建微信小程序(前后端)的全过程
2019/04/15 Javascript
javaScript把其它类型转换为Number类型
2019/10/13 Javascript
JavaScript禁止右击保存图片,禁止拖拽图片的实现代码
2020/04/28 Javascript
Vue.js使用axios动态获取response里的data数据操作
2020/09/08 Javascript
react ant Design手动设置表单的值操作
2020/10/31 Javascript
python getopt 参数处理小示例
2009/06/09 Python
使用 Python 获取 Linux 系统信息的代码
2014/07/13 Python
Python库urllib与urllib2主要区别分析
2014/07/13 Python
Python调用C# Com dll组件实战教程
2017/10/12 Python
python实现协同过滤推荐算法完整代码示例
2017/12/15 Python
Python使用base64模块进行二进制数据编码详解
2018/01/11 Python
BONIA官方网站:国际奢侈品牌和皮革专家
2016/11/27 全球购物
函授毕业自我鉴定
2013/12/19 职场文书
趣味游戏活动方案
2014/02/07 职场文书
趣味比赛活动方案
2014/02/15 职场文书
交通事故委托书范本精选
2014/10/04 职场文书
对照四风自我剖析材料
2014/10/07 职场文书
党员批评与自我批评思想汇报
2014/10/08 职场文书
小学优秀班主任材料
2014/12/17 职场文书
公司保洁员岗位职责
2015/02/13 职场文书
2019年大学生职业生涯规划书
2019/03/25 职场文书
如何写一份成功的商业计划书
2019/06/25 职场文书
教您怎么制定西餐厅运营方案 ?
2019/07/05 职场文书
100句拼搏进取的名言警句,值得一读!
2019/10/07 职场文书
NodeJs内存占用过高的排查实战记录
2021/05/10 NodeJs
Java各种比较对象的方式的对比总结
2021/06/20 Java/Android