浅谈sklearn中predict与predict_proba区别


Posted in Python onJune 28, 2020

predict_proba 返回的是一个 n 行 k 列的数组,列是标签(有排序), 第 i 行 第 j 列上的数值是模型预测 第 i 个预测样本为某个标签的概率,并且每一行的概率和为1。

predict 直接返回的是预测 的标签。

具体见下面示例:

# conding :utf-8 
from sklearn.linear_model import LogisticRegression 
import numpy as np 
x_train = np.array([[1,2,3], 
          [1,3,4], 
          [2,1,2], 
          [4,5,6], 
          [3,5,3], 
          [1,7,2]]) 
 
y_train = np.array([3, 3, 3, 2, 2, 2]) 
 
x_test = np.array([[2,2,2], 
          [3,2,6], 
          [1,7,4]]) 
 
clf = LogisticRegression() 
clf.fit(x_train, y_train) 
 
# 返回预测标签 
print(clf.predict(x_test)) 
 
# 返回预测属于某标签的概率 
print(clf.predict_proba(x_test)) 
 
# [2 3 2] 
#
# [[0.56651809 0.43348191] 
# [0.15598162 0.84401838] 
# [0.86852502 0.13147498]] 
# 分析结果: 
# 标签是 2,3 共两个,所以predict_proba返回的为2列,且是排序的(第一列为标签2,第二列为标签3),
# 返回矩阵的行数是测试样本个数 因此为3行
# 预测[2,2,2]的标签是2的概率为0.56651809,3的概率为0.43348191 
# 
# 预测[3,2,6]的标签是2的概率为0.15598162,3的概率为0.84401838 
# 
# 预测[1,7,4]的标签是2的概率为0.86852502,3的概率为0.13147498

补充知识:sklearn中predict与predict_proba的识别结果不一致

今天训练了好久的决策树模型在测试的时候发现个bug,使用predict得到的结果居然不是predict_proba中最大数值的索引!因为脚本中需要模型的置信度,所以希望拿到predict_proba的类别概率。

经过胡乱分析发现predict_proba得到的维度比总类别数少了几个,经过测试发现就是这个造成的,即训练集中有部分类别样本数为0。这个问题比较隐蔽,记录一下方便天涯沦落人绕坑。

Tip:在sklearn的train_test_split中有一个参数可以强制测试集和训练集的数据分布一致,也就不会导致缺类别的问题。

以上这篇浅谈sklearn中predict与predict_proba区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
MySQLdb ImportError: libmysqlclient.so.18解决方法
Aug 21 Python
Python实现数通设备端口使用情况监控实例
Jul 15 Python
Python时间模块datetime、time、calendar的使用方法
Jan 13 Python
Python实现堆排序的方法详解
May 03 Python
Python实现的十进制小数与二进制小数相互转换功能
Oct 12 Python
python3爬取各类天气信息
Feb 24 Python
Python使用re模块实现信息筛选的方法
Apr 29 Python
用Python一键搭建Http服务器的方法
Jun 01 Python
Python实现简单石头剪刀布游戏
Jan 20 Python
Python实现数据结构线性链表(单链表)算法示例
May 04 Python
Python vtk读取并显示dicom文件示例
Jan 13 Python
Python中使用filter过滤列表的一个小技巧分享
May 02 Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 #Python
pytorch 常用函数 max ,eq说明
Jun 28 #Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 #Python
You might like
PHP 定界符 使用技巧
2009/06/14 PHP
关于php程序报date()警告的处理(date_default_timezone_set)
2013/10/22 PHP
PHP编译安装中遇到的两个错误和解决方法
2014/08/20 PHP
PHP使用 Imagick 扩展实现图片合成,圆角处理功能示例
2019/09/09 PHP
Js 订制自己的AlertBox(信息提示框)
2009/01/09 Javascript
Jquery 最近浏览过的商品的功能实现代码
2010/05/14 Javascript
jquery 查找新建元素代码
2010/07/06 Javascript
js中settimeout方法加参数
2014/02/28 Javascript
jQuery中each()方法用法实例
2014/12/27 Javascript
jquery实现无限分级横向导航菜单的方法
2015/03/12 Javascript
JQuery选择器、过滤器大整理
2015/05/26 Javascript
基于dropdown.js实现的两款美观大气的二级导航菜单
2015/09/02 Javascript
学习JavaScript设计模式之单例模式
2016/01/19 Javascript
jQuery模拟物体自由落体运动(附演示与demo源码下载)
2016/01/21 Javascript
vue 动态修改a标签的样式的方法
2018/01/18 Javascript
JS实现匀速与减速缓慢运动的动画效果封装示例
2018/08/27 Javascript
Vue传参一箩筐(页面、组件)
2019/04/04 Javascript
详解关于html,css,js三者的加载顺序问题
2019/04/10 Javascript
[01:50]2014DOTA2西雅图邀请赛 专访欢乐周宝龙
2014/07/08 DOTA
[03:11]完美世界DOTA2联赛PWL DAY8集锦
2020/11/09 DOTA
Python获取远程文件大小的函数代码分享
2014/05/13 Python
Python常用模块用法分析
2014/09/08 Python
python实现简单购物商城
2016/05/21 Python
利用selenium 3.7和python3添加cookie模拟登陆的实现
2017/11/20 Python
Python socket模块ftp传输文件过程解析
2019/11/05 Python
浅谈pandas.cut与pandas.qcut的使用方法及区别
2020/03/03 Python
python代码xml转txt实例
2020/03/10 Python
python如何进入交互模式
2020/07/06 Python
jupyter notebook 写代码自动补全的实现
2020/11/02 Python
详解h5页面在不同ios设备上的问题总结
2019/03/01 HTML / CSS
Kathmandu新西兰官网:新西兰户外运动品牌
2019/07/27 全球购物
人事科岗位职责范本
2014/03/02 职场文书
yy司仪主持词
2014/03/22 职场文书
2014第二批党员干部对照“四风”找差距检查材料思想汇报
2014/09/18 职场文书
谁动了我的奶酪读书笔记
2015/06/30 职场文书
2019年英语版感谢信(8篇)
2019/09/29 职场文书