sklearn的predict_proba使用说明


Posted in Python onJune 28, 2020

发现个很有用的方法——predict_proba

今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到。

我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率。(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0、1就有两个概率)

举例:

获取数据及预测代码:

from sklearn.linear_model import LogisticRegression
import numpy as np
 
train_X = np.array(np.random.randint(0,10,size=30).reshape(10,3))
train_y = np.array(np.random.randint(0,2,size=10))
test_X = np.array(np.random.randint(0,10,size=12).reshape(4,3))
 
model = LogisticRegression()
model.fit(train_X,train_y)
test_y = model.predict_proba(test_X)
 
print(train_X)
print(train_y)
print(test_y)

训练数据

[[2 9 8]
 [0 8 5]
 [7 1 2]
 [8 4 6]
 [8 8 3]
 [7 2 7]
 [6 4 3]
 [1 4 4]
 [1 9 3]
 [3 4 7]]

训练结果,与训练数据一一对应:

[1 1 1 0 1 1 0 0 0 1]

测试数据:

[[4 3 0]  #测试数据
 [3 0 4]
 [2 9 5]
 [2 8 5]]

测试结果,与测试数据一一对应:

[[0.48753831 0.51246169] 
 [0.58182694 0.41817306]
 [0.85361393 0.14638607]
 [0.57018655 0.42981345]]

可以看出,有四行两列,每行对应一条预测数据,两列分别对应 对于0、1的预测概率(左边概率大于0.5则为0,反之为1)

我们来看看使用predict方法获得的结果:

test_y = model.predict(test_X)
print(test_y)

输出结果:[1,0,0,0]

所以有的情况下predict_proba还是很有用的,它可以获得对每种可能结果的概率,使用predict则是直接获得唯一的预测结果,所以在使用的时候,应该灵活使用。

补充一个知识点:关于预测结果标签如何与原来标签相对应

predict_proba返回所有标签值可能性概率值,这些值是如何排序的呢?

返回模型中每个类的样本概率,其中类按类self.classes_进行排序。

其中关键的步骤为numpy的unique方法,即通过np.unique(Label)方法,对Label中的所有标签值进行从小到大的去重排序。得到一个从小到大唯一值的排序。这也就对应于predict_proba的行返回结果。

补充知识: python sklearn decision_function、predict_proba、predict

看代码~

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
X = np.array([[-1,-1],[-2,-1],[1,1],[2,1],[-1,1],[-1,2],[1,-1],[1,-2]])
y = np.array([0,0,1,1,2,2,3,3])
# y=np.array([1,1,2,2,3,3,4,4])
# clf = SVC(decision_function_shape="ovr",probability=True)
clf = SVC(probability=True)
clf.fit(X, y)
print(clf.decision_function(X))
'''
对于n分类,会有n个分类器,然后,任意两个分类器都可以算出一个分类界面,这样,用decision_function()时,对于任意一个样例,就会有n*(n-1)/2个值。
任意两个分类器可以算出一个分类界面,然后这个值就是距离分类界面的距离。
我想,这个函数是为了统计画图,对于二分类时最明显,用来统计每个点离超平面有多远,为了在空间中直观的表示数据以及画超平面还有间隔平面等。
decision_function_shape="ovr"时是4个值,为ovo时是6个值。
'''
print(clf.predict(X))
clf.predict_proba(X) #这个是得分,每个分类器的得分,取最大得分对应的类。
#画图
plot_step=0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
           np.arange(y_min, y_max, plot_step))
 
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #对坐标风格上的点进行预测,来画分界面。其实最终看到的类的分界线就是分界面的边界线。
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")
 
class_names="ABCD"
plot_colors="rybg"
for i, n, c in zip(range(4), class_names, plot_colors):
  idx = np.where(y == i) #i为0或者1,两个类
  plt.scatter(X[idx, 0], X[idx, 1],
        c=c, cmap=plt.cm.Paired,
        label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
plt.show()

sklearn的predict_proba使用说明

以上这篇sklearn的predict_proba使用说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现求最大公约数及判断素数的方法
May 26 Python
python基于phantomjs实现导入图片
May 13 Python
python内置函数:lambda、map、filter简单介绍
Nov 16 Python
django使用html模板减少代码代码解析
Dec 12 Python
Python3.6实现连接mysql或mariadb的方法分析
May 18 Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 Python
Python拼接字符串的7种方式详解
Mar 19 Python
python 实现读取csv数据,分类求和 再写进 csv
May 18 Python
关于python tushare Tkinter构建的简单股票可视化查询系统(Beta v0.13)
Oct 19 Python
PyQt5的QWebEngineView使用示例
Oct 20 Python
Ubuntu权限不足无法创建文件夹解决方案
Nov 14 Python
python文件路径操作方法总结
Dec 21 Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
You might like
桌面中心(四)数据显示
2006/10/09 PHP
ThinkPHP 防止表单重复提交的方法
2011/08/08 PHP
php简单判断文本编码的方法
2015/07/30 PHP
深入php内核之php in array
2015/11/10 PHP
javascript 操作Word和Excel的实现代码
2009/10/26 Javascript
JavaScript CSS修改学习第六章 拖拽
2010/02/19 Javascript
顶部缓冲下拉菜单导航特效的JS代码
2013/08/27 Javascript
如何学习Javascript入门指导
2013/11/01 Javascript
Ajax同步与异步传输的示例代码
2013/11/21 Javascript
jquery 隐藏与显示tr标签示例代码
2014/06/06 Javascript
基于JS实现简单的样式切换效果代码
2015/09/04 Javascript
JS使用单链表统计英语单词出现次数
2016/06/16 Javascript
微信小程序 教程之数据绑定
2016/10/18 Javascript
jQuery简单绑定单个事件的方法示例
2017/06/10 jQuery
详解利用Angular实现多团队模块化SPA开发框架
2017/11/27 Javascript
手挽手带你学React之React-router4.x的使用
2019/02/14 Javascript
微信小程序学习笔记之目录结构、基本配置图文详解
2019/03/28 Javascript
使用Vue-cli3.0创建的项目 如何发布npm包
2019/10/10 Javascript
Vue中nprogress页面加载进度条的方法实现
2020/11/13 Javascript
[02:05]2014DOTA2西雅图邀请赛 老队长全明星大猜想谁不服就按进显示器
2014/07/08 DOTA
[01:10:58]KG vs TNC 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/16 DOTA
Python制作钉钉加密/解密工具
2016/12/07 Python
Python实现在线音乐播放器
2017/03/03 Python
Python GUI Tkinter简单实现个性签名设计
2018/06/19 Python
Python3简单爬虫抓取网页图片代码实例
2019/08/26 Python
Python爬取破解无线网络wifi密码过程解析
2019/09/17 Python
Tensorflow实现多GPU并行方式
2020/02/03 Python
ITK 实现多张图像转成单个nii.gz或mha文件案例
2020/07/01 Python
python3:excel操作之读取数据并返回字典 + 写入的案例
2020/09/01 Python
浅谈Selenium+Webdriver 常用的元素定位方式
2021/01/13 Python
ALEX AND ANI:手镯,项链,耳环和更多
2017/04/20 全球购物
英国森林假期:Forest Holidays
2021/01/01 全球购物
C#面试常见问题
2013/02/25 面试题
2014年会演讲稿范文
2014/01/06 职场文书
专业技术职务聘任书
2014/03/29 职场文书
MySql数据库 查询时间序列间隔
2022/05/11 MySQL