sklearn的predict_proba使用说明


Posted in Python onJune 28, 2020

发现个很有用的方法——predict_proba

今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到。

我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率。(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0、1就有两个概率)

举例:

获取数据及预测代码:

from sklearn.linear_model import LogisticRegression
import numpy as np
 
train_X = np.array(np.random.randint(0,10,size=30).reshape(10,3))
train_y = np.array(np.random.randint(0,2,size=10))
test_X = np.array(np.random.randint(0,10,size=12).reshape(4,3))
 
model = LogisticRegression()
model.fit(train_X,train_y)
test_y = model.predict_proba(test_X)
 
print(train_X)
print(train_y)
print(test_y)

训练数据

[[2 9 8]
 [0 8 5]
 [7 1 2]
 [8 4 6]
 [8 8 3]
 [7 2 7]
 [6 4 3]
 [1 4 4]
 [1 9 3]
 [3 4 7]]

训练结果,与训练数据一一对应:

[1 1 1 0 1 1 0 0 0 1]

测试数据:

[[4 3 0]  #测试数据
 [3 0 4]
 [2 9 5]
 [2 8 5]]

测试结果,与测试数据一一对应:

[[0.48753831 0.51246169] 
 [0.58182694 0.41817306]
 [0.85361393 0.14638607]
 [0.57018655 0.42981345]]

可以看出,有四行两列,每行对应一条预测数据,两列分别对应 对于0、1的预测概率(左边概率大于0.5则为0,反之为1)

我们来看看使用predict方法获得的结果:

test_y = model.predict(test_X)
print(test_y)

输出结果:[1,0,0,0]

所以有的情况下predict_proba还是很有用的,它可以获得对每种可能结果的概率,使用predict则是直接获得唯一的预测结果,所以在使用的时候,应该灵活使用。

补充一个知识点:关于预测结果标签如何与原来标签相对应

predict_proba返回所有标签值可能性概率值,这些值是如何排序的呢?

返回模型中每个类的样本概率,其中类按类self.classes_进行排序。

其中关键的步骤为numpy的unique方法,即通过np.unique(Label)方法,对Label中的所有标签值进行从小到大的去重排序。得到一个从小到大唯一值的排序。这也就对应于predict_proba的行返回结果。

补充知识: python sklearn decision_function、predict_proba、predict

看代码~

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
X = np.array([[-1,-1],[-2,-1],[1,1],[2,1],[-1,1],[-1,2],[1,-1],[1,-2]])
y = np.array([0,0,1,1,2,2,3,3])
# y=np.array([1,1,2,2,3,3,4,4])
# clf = SVC(decision_function_shape="ovr",probability=True)
clf = SVC(probability=True)
clf.fit(X, y)
print(clf.decision_function(X))
'''
对于n分类,会有n个分类器,然后,任意两个分类器都可以算出一个分类界面,这样,用decision_function()时,对于任意一个样例,就会有n*(n-1)/2个值。
任意两个分类器可以算出一个分类界面,然后这个值就是距离分类界面的距离。
我想,这个函数是为了统计画图,对于二分类时最明显,用来统计每个点离超平面有多远,为了在空间中直观的表示数据以及画超平面还有间隔平面等。
decision_function_shape="ovr"时是4个值,为ovo时是6个值。
'''
print(clf.predict(X))
clf.predict_proba(X) #这个是得分,每个分类器的得分,取最大得分对应的类。
#画图
plot_step=0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
           np.arange(y_min, y_max, plot_step))
 
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #对坐标风格上的点进行预测,来画分界面。其实最终看到的类的分界线就是分界面的边界线。
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")
 
class_names="ABCD"
plot_colors="rybg"
for i, n, c in zip(range(4), class_names, plot_colors):
  idx = np.where(y == i) #i为0或者1,两个类
  plt.scatter(X[idx, 0], X[idx, 1],
        c=c, cmap=plt.cm.Paired,
        label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
plt.show()

sklearn的predict_proba使用说明

以上这篇sklearn的predict_proba使用说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用IPython下的Net-SNMP来管理类UNIX系统的教程
Apr 15 Python
Python聚类算法之基本K均值实例详解
Nov 20 Python
用Python实现KNN分类算法
Dec 22 Python
Python中str.join()简单用法示例
Mar 20 Python
python切片及sys.argv[]用法详解
May 25 Python
python求解数组中两个字符串的最小距离
Sep 27 Python
Django 路由系统URLconf的使用
Oct 11 Python
python学生信息管理系统(完整版)
Apr 05 Python
python 实现在tkinter中动态显示label图片的方法
Jun 13 Python
python 中如何获取列表的索引
Jul 02 Python
PHP统计代码行数的小代码
Sep 19 Python
在Django中预防CSRF攻击的操作
Mar 13 Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
You might like
php 高效率写法 推荐
2010/02/21 PHP
PHP FATAL ERROR: CALL TO UNDEFINED FUNCTION BCMUL()解决办法
2014/05/04 PHP
TNC vs IO BO3 第一场2.13
2021/03/10 DOTA
jquery1.4 教程二 ajax方法的改进
2010/02/25 Javascript
十个迅速提升JQuery性能让你的JQuery跑得更快
2012/12/10 Javascript
Javascript代码在页面加载时的执行顺序介绍
2013/05/03 Javascript
JS Date函数整理方便使用
2013/10/23 Javascript
js向上无缝滚动,网站公告效果 具体代码
2013/11/18 Javascript
jquery中get和post的简单实例
2014/02/04 Javascript
正则表达式优化JSON字符串的技巧
2015/12/24 Javascript
浅析location.href跨窗口调用函数
2016/11/22 Javascript
jquery实现自定义图片裁剪功能【推荐】
2017/03/08 Javascript
Angular.js 4.x中表单Template-Driven Forms详解
2017/04/25 Javascript
Angular6新特性之Angular Material
2018/12/28 Javascript
如何使用less实现随机下雪动画详解
2019/01/02 Javascript
vux-scroller实现移动端上拉加载功能过程解析
2019/10/08 Javascript
[43:14]Liquid vs Optic 2018国际邀请赛淘汰赛BO3 第二场 8.21
2018/08/22 DOTA
python标准日志模块logging的使用方法
2013/11/01 Python
全面解读Python Web开发框架Django
2014/06/30 Python
wxpython中利用线程防止假死的实现方法
2014/08/11 Python
零基础写python爬虫之打包生成exe文件
2014/11/06 Python
Python中fnmatch模块的使用详情
2018/11/30 Python
python-Web-flask-视图内容和模板知识点西宁街
2019/08/23 Python
Windows上安装tensorflow  详细教程(图文详解)
2020/02/04 Python
Django解决frame拒绝问题的方法
2020/12/18 Python
The Hut美国/加拿大:英国领先的豪华在线百货商店
2019/03/26 全球购物
北美最大的零售退货翻新商:VIP Outlet
2019/11/21 全球购物
学生实习自我鉴定
2013/10/11 职场文书
小学安全教育月活动总结
2014/07/07 职场文书
办理护照工作证明
2014/10/10 职场文书
小学少先队辅导员述职报告
2015/01/10 职场文书
2015年车间主任工作总结
2015/05/21 职场文书
2015年依法治校工作总结
2015/07/27 职场文书
MySQL 表空间碎片的概念及相关问题解决
2021/05/07 MySQL
python关于集合的知识案例详解
2021/05/30 Python
Python使用socket去实现TCP客户端和TCP服务端
2022/04/12 Python