PyTorch中topk函数的用法详解


Posted in Python onJanuary 02, 2020

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

largest:如果为True,按照大到小排序; 如果为False,按照小到大排序

sorted:返回的结果按照顺序返回

out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。

假设一个PyTorch中topk函数的用法详解,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
    [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
    [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
    [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],  #【0,0】代表 第一个样本最可能属于第一类别
    [2],  # 【1, 0】代表第二个样本最可能属于第二类别
    [4],
    [2]])
# indices_max等于indices
tensor([[True],
    [True],
    [True],
    [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
    [-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
    [ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
    [ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
    [2, 1],
    [2, 1],
    [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。

其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

以上这篇PyTorch中topk函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python字符串对其居中显示的方法
Jul 11 Python
Python 转义字符详细介绍
Mar 21 Python
django的聚合函数和aggregate、annotate方法使用详解
Jul 23 Python
详解Python 字符串相似性的几种度量方法
Aug 29 Python
用python的turtle模块实现给女票画个小心心
Nov 23 Python
opencv之为图像添加边界的方法示例
Dec 26 Python
基于matplotlib xticks用法详解
Apr 16 Python
解决matplotlib.pyplot在Jupyter notebook中不显示图像问题
Apr 22 Python
PyQt5实现登录页面
May 30 Python
如何让PyQt5中QWebEngineView与JavaScript交互
Oct 21 Python
python基于tkinter实现gif录屏功能
May 19 Python
为了顺利买到演唱会的票用Python制作了自动抢票的脚本
Oct 16 Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
You might like
深入了解PHP类Class的概念
2012/06/14 PHP
PHP字符串中特殊符号的过滤方法介绍
2014/02/18 PHP
php校验表单检测字段是否为空的方法
2015/03/20 PHP
微信支付开发发货通知实例
2016/07/12 PHP
PHP将字符串首字母大小写转换的实例
2017/01/21 PHP
PHP 记录访客的浏览信息方法
2018/01/29 PHP
基于PHP实现用户在线状态检测
2020/11/10 PHP
js重写alert控件(适合学习js的新手朋友)
2014/08/24 Javascript
jQuery实现图片轮播特效代码分享
2015/09/15 Javascript
BootStrap CSS全局样式和表格样式源码解析
2017/01/20 Javascript
JS中Attr的用法详解
2017/10/09 Javascript
Bootstrap Table列宽拖动的方法
2018/08/15 Javascript
发布Angular应用至生产环境的方法
2018/12/10 Javascript
vue实现Excel文件的上传与下载功能的两种方式
2019/06/28 Javascript
小程序分页实践之编写可复用分页组件
2019/07/18 Javascript
vue实现分页加载效果
2019/12/24 Javascript
Python实现将Excel转换为json的方法示例
2017/08/05 Python
Python实现批量执行同目录下的py文件方法
2019/01/11 Python
Python实现判断一个整数是否为回文数算法示例
2019/03/02 Python
Django文件存储 默认存储系统解析
2019/08/02 Python
浅析python内置模块collections
2019/11/15 Python
基于python的列表list和集合set操作
2019/11/24 Python
Python迭代器模块itertools使用原理解析
2019/12/11 Python
python使用列表的最佳方案
2020/08/12 Python
安装不同版本的tensorflow与models方法实现
2021/02/20 Python
美国在线购物频道:Shop LC
2019/04/21 全球购物
大学生活自我评价
2014/04/09 职场文书
副科级后备干部考察材料
2014/05/15 职场文书
军训拉歌口号
2014/06/13 职场文书
优秀应届本科生求职信
2014/07/19 职场文书
学校组织向国旗敬礼活动方案(中小学适用)
2014/09/27 职场文书
说谎欺骗人检讨书300字
2014/11/18 职场文书
办公室主任岗位竞聘书
2015/09/15 职场文书
2016春季校长开学典礼致辞
2015/11/26 职场文书
CSS的class与id常用的命名规则
2021/05/18 HTML / CSS
Win10/Win11 任务栏替换成经典样式
2022/04/19 数码科技