PyTorch中topk函数的用法详解


Posted in Python onJanuary 02, 2020

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

largest:如果为True,按照大到小排序; 如果为False,按照小到大排序

sorted:返回的结果按照顺序返回

out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。

假设一个PyTorch中topk函数的用法详解,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
    [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
    [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
    [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],  #【0,0】代表 第一个样本最可能属于第一类别
    [2],  # 【1, 0】代表第二个样本最可能属于第二类别
    [4],
    [2]])
# indices_max等于indices
tensor([[True],
    [True],
    [True],
    [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
    [-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
    [ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
    [ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
    [2, 1],
    [2, 1],
    [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。

其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

以上这篇PyTorch中topk函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python数据结构之二叉树的遍历实例
Apr 29 Python
selenium+python自动化测试环境搭建步骤
Jun 03 Python
python set内置函数的具体使用
Jul 02 Python
使用Rasterio读取栅格数据的实例讲解
Nov 26 Python
numpy中三维数组中加入元素后的位置详解
Nov 28 Python
python单例设计模式实现解析
Jan 07 Python
基于Python和PyYAML读取yaml配置文件数据
Jan 13 Python
翻转数列python实现,求前n项和,并能输出整个数列的案例
May 03 Python
python实现学生信息管理系统源码
Feb 22 Python
Python+Tkinter制作专属图形化界面
Apr 01 Python
详解OpenCV曝光融合
Apr 29 Python
python神经网络Xception模型
May 06 Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
You might like
php实现俄罗斯乘法实例
2015/03/07 PHP
PHP CURL 多线程操作代码实例
2015/05/13 PHP
详解yii2实现分库分表的方案与思路
2017/02/03 PHP
php模式设计之观察者模式应用实例分析
2019/09/25 PHP
laravel框架模型和数据库基础操作实例详解
2020/01/25 PHP
动态添加js事件实现代码
2009/03/12 Javascript
js获取判断上传文件后缀名的示例代码
2014/02/19 Javascript
推荐9款炫酷的基于jquery的页面特效
2014/12/07 Javascript
javascript实现动态改变层大小的方法
2015/05/14 Javascript
JavaScript编程中的Promise使用大全
2015/07/28 Javascript
jQuery带进度条全屏图片轮播特效代码分享
2020/06/28 Javascript
在ASP.NET MVC项目中使用RequireJS库的用法示例
2016/02/15 Javascript
js轮播图代码分享
2016/07/14 Javascript
js 判断数据类型的几种方法
2017/01/13 Javascript
浅谈es6语法 (Proxy和Reflect的对比)
2017/10/24 Javascript
Vue封装Swiper实现图片轮播效果
2018/02/06 Javascript
nodejs更新package.json中的dependencies依赖到最新版本的方法
2018/10/10 NodeJs
spring+angular实现导出excel的实现代码
2019/02/27 Javascript
javascript移动端 电子书 翻页效果实现代码
2019/09/07 Javascript
基于Nuxt.js项目的服务端性能优化与错误检测(容错处理)
2019/10/23 Javascript
Vue自定义指令结合阿里云OSS优化图片的实现方法
2019/11/12 Javascript
基于JavaScript实现控制下拉列表
2020/05/08 Javascript
vue element-ui中table合计指定列求和实例
2020/11/02 Javascript
[01:02:45]完美世界DOTA2联赛 LBZS vs Forest 第三场 11.07
2020/11/09 DOTA
浅谈Python程序与C++程序的联合使用
2015/04/07 Python
Python 数据结构之旋转链表
2017/02/25 Python
PyCharm 设置SciView工具窗口的方法
2019/01/15 Python
python中Array和DataFrame相互转换的实例讲解
2021/02/03 Python
使用CSS3实现一个3D相册效果实例
2016/12/03 HTML / CSS
简单几步用纯CSS3实现3D翻转效果
2019/01/17 HTML / CSS
为智能设备设计个性化保护套网站:caseable
2017/01/05 全球购物
无畏的旅行:Intrepid Travel
2017/12/20 全球购物
ESDlife健康生活易:身体检查预订、搜寻及比较
2019/05/10 全球购物
英语翻译系毕业生求职信
2013/09/29 职场文书
导游词之河姆渡遗址博物馆
2019/10/10 职场文书
Python requests用法和django后台处理详解
2022/03/19 Python