PyTorch中topk函数的用法详解


Posted in Python onJanuary 02, 2020

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

largest:如果为True,按照大到小排序; 如果为False,按照小到大排序

sorted:返回的结果按照顺序返回

out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。

假设一个PyTorch中topk函数的用法详解,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
    [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
    [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
    [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],  #【0,0】代表 第一个样本最可能属于第一类别
    [2],  # 【1, 0】代表第二个样本最可能属于第二类别
    [4],
    [2]])
# indices_max等于indices
tensor([[True],
    [True],
    [True],
    [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
    [-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
    [ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
    [ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
    [2, 1],
    [2, 1],
    [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。

其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

以上这篇PyTorch中topk函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 正则表达式 概述及常用字符
May 04 Python
使用优化器来提升Python程序的执行效率的教程
Apr 02 Python
横向对比分析Python解析XML的四种方式
Mar 30 Python
KMP算法精解及其Python版的代码示例
Jun 01 Python
Python实现字典去除重复的方法示例
Jul 31 Python
Python实现改变与矩形橡胶的线条的颜色代码示例
Jan 05 Python
在Python 不同级目录之间模块的调用方法
Jan 19 Python
在Python中获取操作系统的进程信息
Aug 27 Python
python飞机大战pygame游戏背景设计详解
Dec 17 Python
pytorch 中pad函数toch.nn.functional.pad()的用法
Jan 08 Python
Python itertools.product方法代码实例
Mar 27 Python
移除Selenium中window.navigator.webdriver值
Jun 10 Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
You might like
php方法调用模式与函数调用模式简例
2011/09/20 PHP
PHP超全局数组(Superglobals)介绍
2015/07/01 PHP
PHP对XML内容进行修改和删除实例代码
2016/10/26 PHP
一行代码告别document.getElementById
2012/06/01 Javascript
Jquery操作下拉框(DropDownList)实现取值赋值
2013/08/13 Javascript
javascript根据像素点取位置示例
2014/01/27 Javascript
jquery实现下拉菜单的二级联动利用json对象从DB取值显示联动
2014/03/27 Javascript
js面向对象编程之如何实现方法重载
2014/07/02 Javascript
javascript实现点击单选按钮链接转向对应网址的方法
2015/08/12 Javascript
使用javascript函数编写简单银行取钱存钱流程
2018/05/26 Javascript
使用vue-cli脚手架工具搭建vue-webpack项目
2019/01/14 Javascript
vue把输入框的内容添加到页面的实例讲解
2019/11/11 Javascript
微信小程序新闻网站详情页实例代码
2020/01/10 Javascript
js生成1到100的随机数最简单的实现方法
2020/02/07 Javascript
Vue结合路由配置递归实现菜单栏功能
2020/06/16 Javascript
[01:22:10]Ti4 循环赛第二日 DK vs Empire
2014/07/11 DOTA
numpy数组拼接简单示例
2017/12/15 Python
python 查找文件名包含指定字符串的方法
2018/06/05 Python
PyQt5 QListWidget选择多项并返回的实例
2019/06/17 Python
python爬虫selenium和phantomJs使用方法解析
2019/08/08 Python
python 发送json数据操作实例分析
2019/10/15 Python
python-web根据元素属性进行定位的方法
2019/12/13 Python
Python抓包程序mitmproxy安装和使用过程图解
2020/03/02 Python
Scrapy实现模拟登录的示例代码
2021/02/21 Python
中间件分为哪几类
2012/03/14 面试题
初二政治教学反思
2014/01/12 职场文书
应聘编辑自荐信范文
2014/03/12 职场文书
敬老院院长事迹材料
2014/05/21 职场文书
开工典礼策划方案
2014/05/23 职场文书
小学课外活动总结
2014/07/09 职场文书
乡镇党的群众路线教育实践活动剖析材料
2014/10/09 职场文书
2014年保管员工作总结
2014/11/18 职场文书
项目投资意向书范本
2015/05/09 职场文书
jquery插件实现搜索历史
2021/04/24 jQuery
vue组件的路由高亮问题解决方法
2021/05/11 Vue.js
JavaScript设计模式之原型模式详情
2022/06/21 Javascript