PyTorch中topk函数的用法详解


Posted in Python onJanuary 02, 2020

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

largest:如果为True,按照大到小排序; 如果为False,按照小到大排序

sorted:返回的结果按照顺序返回

out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。

假设一个PyTorch中topk函数的用法详解,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
    [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
    [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
    [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],  #【0,0】代表 第一个样本最可能属于第一类别
    [2],  # 【1, 0】代表第二个样本最可能属于第二类别
    [4],
    [2]])
# indices_max等于indices
tensor([[True],
    [True],
    [True],
    [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
    [-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
    [ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
    [ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
    [2, 1],
    [2, 1],
    [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。

其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

以上这篇PyTorch中topk函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 网络编程起步(Socket发送消息)
Sep 06 Python
python使用分治法实现求解最大值的方法
May 12 Python
Python中encode()方法的使用简介
May 18 Python
python在非root权限下的安装方法
Jan 23 Python
python使用锁访问共享变量实例解析
Feb 08 Python
python算法题 链表反转详解
Jul 02 Python
Django 大文件下载实现过程解析
Aug 01 Python
浅析python表达式4+0.5值的数据类型
Feb 26 Python
Python处理mysql特殊字符的问题
Mar 02 Python
pycharm安装及如何导入numpy
Apr 03 Python
Python机器学习三大件之一numpy
May 10 Python
Django操作cookie的实现
May 26 Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
You might like
mysql 搜索之简单应用
2007/04/27 PHP
详解php几行代码实现CSV格式文件输出
2017/07/01 PHP
thinkphp5 migrate数据库迁移工具
2018/02/20 PHP
用Div仿showModalDialog模式菜单的效果的代码
2007/03/05 Javascript
JavaScript 对象模型 执行模型
2009/12/06 Javascript
javascript阻止浏览器后退事件防止误操作清空表单
2013/11/22 Javascript
浅析jQuery(function(){})与(function(){})(jQuery)之间的区别
2014/01/09 Javascript
jquery下div 的resize事件示例代码
2014/03/09 Javascript
jquery的trigger和triggerHandler的区别示例介绍
2014/04/20 Javascript
Jquery的基本对象转换和文档加载用法实例
2015/02/25 Javascript
JavaScript实现文本框中默认显示背景图片在获得焦点后消失的方法
2015/07/01 Javascript
jquery 重写 ajax提交并判断权限后 使用load方法报错解决方法
2016/01/19 Javascript
JavaScript数组的一些奇葩行为
2016/01/25 Javascript
探讨AngularJs中ui.route的简单应用
2016/11/16 Javascript
Javascript blur与click冲突解决办法
2017/01/09 Javascript
jQuery简易时光轴实现方法示例
2017/03/13 Javascript
vue.js 使用v-if v-else发现没有执行解决办法
2017/05/15 Javascript
canvas+gif.js打造自己的数字雨头像的示例代码
2017/10/26 Javascript
转换layUI的数据表格中的日期格式方法
2019/09/19 Javascript
在vue项目中promise解决回调地狱和并发请求的问题
2020/11/09 Javascript
Python验证码识别处理实例
2015/12/28 Python
Python中的defaultdict与__missing__()使用介绍
2018/02/03 Python
python requests post多层字典的方法
2018/12/27 Python
Python实现栈的方法详解【基于数组和单链表两种方法】
2020/02/22 Python
解决pyinstaller打包运行程序时出现缺少plotly库问题
2020/06/02 Python
django rest framework使用django-filter用法
2020/07/15 Python
Surfdome西班牙:世界上最受欢迎的生活方式品牌
2019/02/13 全球购物
新东网科技Java笔试题
2012/07/13 面试题
优秀学生干部个人事迹材料
2014/06/02 职场文书
商铺门前三包责任书
2014/07/25 职场文书
创优争先心得体会
2014/09/11 职场文书
2016三严三实专题教育活动心得体会
2016/01/06 职场文书
vue+element ui实现锚点定位
2021/06/29 Vue.js
MySQL 执行数据库更新update操作的时候数据库卡死了
2022/05/02 MySQL
JavaScript中10个Reduce常用场景技巧
2022/06/21 Javascript
Python+DeOldify实现老照片上色功能
2022/06/21 Python