PyTorch中topk函数的用法详解


Posted in Python onJanuary 02, 2020

听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。

用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

largest:如果为True,按照大到小排序; 如果为False,按照小到大排序

sorted:返回的结果按照顺序返回

out:可缺省,不要

topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。

假设一个PyTorch中topk函数的用法详解,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)

print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
    [-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
    [-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
    [-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3],  #【0,0】代表 第一个样本最可能属于第一类别
    [2],  # 【1, 0】代表第二个样本最可能属于第二类别
    [4],
    [2]])
# indices_max等于indices
tensor([[True],
    [True],
    [True],
    [True]])

现在在尝试一下k=2

import torch

pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
    [-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
    [ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
    [ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
    [2, 1],
    [2, 1],
    [4, 1]])

可以发现indices的shape变成了【4, k】,k=2。

其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。

大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。

以上这篇PyTorch中topk函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python strip lstrip rstrip使用方法
Sep 06 Python
使用Python编写一个模仿CPU工作的程序
Apr 16 Python
Python采用Django开发自己的博客系统
Sep 29 Python
python利用sklearn包编写决策树源代码
Dec 21 Python
Python+matplotlib实现计算两个信号的交叉谱密度实例
Jan 08 Python
利用TensorFlow训练简单的二分类神经网络模型的方法
Mar 05 Python
tensorflow实现简单的卷积网络
May 24 Python
详解django中使用定时任务的方法
Sep 27 Python
如何使用Python处理HDF格式数据及可视化问题
Jun 24 Python
Python-split()函数实例用法讲解
Dec 18 Python
为了顺利买到演唱会的票用Python制作了自动抢票的脚本
Oct 16 Python
Python尝试实现蒙特卡罗模拟期权定价
Apr 21 Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
Python基础之函数基本用法与进阶详解
Jan 02 #Python
You might like
PHP自动更新新闻DIY
2006/10/09 PHP
PHP 服务器配置(使用Apache及IIS两种方法)
2009/06/01 PHP
详解php比较操作符的安全问题
2015/12/03 PHP
PHP new static 和 new self详解
2017/02/19 PHP
PHP排序二叉树基本功能实现方法示例
2018/05/26 PHP
Thinkphp5.0 框架的请求方式与响应方式分析
2019/10/14 PHP
学习YUI.Ext基础第一天
2007/03/10 Javascript
[全兼容哦]--实用、简洁、炫酷的页面转入效果loing
2007/05/07 Javascript
jQuery-Tools-overlay 使用介绍
2012/07/14 Javascript
jquery滚动组件(vticker.js)实现页面动态数据的滚动效果
2013/07/03 Javascript
javascript实现3D变换的立体圆圈实例
2015/08/06 Javascript
js实现tab切换效果实例
2015/09/16 Javascript
JQuery点击事件回到页面顶部效果的实现代码
2016/05/24 Javascript
Bootstrap3.0学习教程之JS折叠插件
2016/05/27 Javascript
jQuery插件echarts设置折线图中折线线条颜色和折线点颜色的方法
2017/03/03 Javascript
Bootstrap栅格系统的使用详解
2017/10/30 Javascript
[40:29]2018DOTA2亚洲邀请赛 4.7总决赛 LGD vs Mineski 第一场
2018/04/10 DOTA
跟老齐学Python之做一个小游戏
2014/09/28 Python
Python中使用不同编码读写txt文件详解
2015/05/28 Python
Python实现简易Web爬虫详解
2018/01/03 Python
VSCode下配置python调试运行环境的方法
2018/04/06 Python
使用celery执行Django串行异步任务的方法步骤
2019/06/06 Python
详解Django将秒转换为xx天xx时xx分
2019/09/27 Python
Python如何使用BeautifulSoup爬取网页信息
2019/11/26 Python
使用Python发现隐藏的wifi
2020/03/04 Python
Python如何将函数值赋给变量
2020/04/28 Python
采购助理岗位职责
2014/02/16 职场文书
我们的节日端午节活动方案
2014/03/02 职场文书
培训研修方案
2014/06/06 职场文书
社团活动总结怎么写
2014/06/30 职场文书
民主生活会批评与自我批评总结
2014/10/17 职场文书
暑期社会实践证明书
2014/11/17 职场文书
就业导师推荐信范文
2015/03/27 职场文书
Go语言中break label与goto label的区别
2021/04/28 Golang
吉利入股戴姆勒后smart“长大了”
2022/04/21 数码科技
el-form每行显示两列底部按钮居中效果的实现
2022/08/05 HTML / CSS