PyTorch中topk函数的用法详解


Posted in Python onJanuary 02, 2020

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

largest:如果为True,按照大到小排序; 如果为False,按照小到大排序

sorted:返回的结果按照顺序返回

out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。

假设一个PyTorch中topk函数的用法详解,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
    [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
    [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
    [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],  #【0,0】代表 第一个样本最可能属于第一类别
    [2],  # 【1, 0】代表第二个样本最可能属于第二类别
    [4],
    [2]])
# indices_max等于indices
tensor([[True],
    [True],
    [True],
    [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
    [-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
    [ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
    [ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
    [2, 1],
    [2, 1],
    [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。

其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

以上这篇PyTorch中topk函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python计数排序和基数排序算法实例
Apr 25 Python
Tensorflow实现卷积神经网络用于人脸关键点识别
Mar 05 Python
浅谈python3发送post请求参数为空的情况
Dec 28 Python
python dict 相同key 合并value的实例
Jan 21 Python
python使用thrift教程的方法示例
Mar 21 Python
500行Python代码打造刷脸考勤系统
Jun 03 Python
Python实现12306火车票抢票系统
Jul 04 Python
PyTorch中Tensor的维度变换实现
Aug 18 Python
python+tifffile之tiff文件读写方式
Jan 13 Python
Django使用Celery加redis执行异步任务的实例内容
Feb 20 Python
pyautogui自动化控制鼠标和键盘操作的步骤
Apr 01 Python
flask框架中的cookie和session使用
Jan 31 Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
You might like
Zend Framework教程之配置文件application.ini解析
2016/03/10 PHP
php实现按天数、星期、月份查询的搜索框
2016/05/02 PHP
Yii 2.0在Grid中格式化时间方法示例
2017/06/06 PHP
PHP大文件切割上传并带进度条功能示例
2019/07/01 PHP
js实现的切换面板实例代码
2013/06/17 Javascript
JS 去除Array中的null值示例代码
2013/11/20 Javascript
Ext修改GridPanel数据和字体颜色、css属性等
2014/06/13 Javascript
javascript 获取浏览器版本
2015/01/21 Javascript
基于JavaScript实现简单的随机抽奖小程序
2016/01/05 Javascript
浅谈jQuery this和$(this)的区别及获取$(this)子元素对象的方法
2016/11/29 Javascript
JS使用cookie实现只出现一次的广告代码效果
2017/04/22 Javascript
React学习笔记之高阶组件应用
2018/06/02 Javascript
JS简单表单验证功能完整示例
2020/01/26 Javascript
Vue路由切换页面不更新问题解决方案
2020/07/10 Javascript
[02:40]DOTA2英雄基础教程 巨牙海民
2013/12/23 DOTA
Python入门及进阶笔记 Python 内置函数小结
2014/08/09 Python
Python re模块介绍
2014/11/30 Python
python实现解数独程序代码
2017/04/12 Python
Flask数据库迁移简单介绍
2017/10/24 Python
Python文本特征抽取与向量化算法学习
2017/12/22 Python
Python实现加载及解析properties配置文件的方法
2018/03/29 Python
python浪漫表白源码
2019/04/05 Python
使用Python将字符串转换为格式化的日期时间字符串
2019/09/01 Python
python 连续不等式语法糖实例
2020/04/15 Python
CSS伪类与CSS伪元素的区别及由来具体说明
2012/12/07 HTML / CSS
HTML5实现页面切换激活的PageVisibility API使用初探
2016/05/13 HTML / CSS
瑞士灯具购物网站:Lampenwelt.ch
2018/07/08 全球购物
飞利信loadrunner和软件测试笔试题
2012/09/22 面试题
销售行政专员职责
2014/01/03 职场文书
上班离岗检讨书
2014/01/27 职场文书
网络文明传播志愿者活动方案
2014/08/20 职场文书
大学生学习新党章思想汇报
2014/10/25 职场文书
民政工作个人总结
2015/02/28 职场文书
教师聘用意向书
2015/05/11 职场文书
iSCSI服务器CHAP双向认证配置
2022/04/01 Servers
uniapp开发打包多端应用完整方法指南
2022/12/24 Javascript