sklearn的predict_proba使用说明


Posted in Python onJune 28, 2020

发现个很有用的方法——predict_proba

今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到。

我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率。(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0、1就有两个概率)

举例:

获取数据及预测代码:

from sklearn.linear_model import LogisticRegression
import numpy as np
 
train_X = np.array(np.random.randint(0,10,size=30).reshape(10,3))
train_y = np.array(np.random.randint(0,2,size=10))
test_X = np.array(np.random.randint(0,10,size=12).reshape(4,3))
 
model = LogisticRegression()
model.fit(train_X,train_y)
test_y = model.predict_proba(test_X)
 
print(train_X)
print(train_y)
print(test_y)

训练数据

[[2 9 8]
 [0 8 5]
 [7 1 2]
 [8 4 6]
 [8 8 3]
 [7 2 7]
 [6 4 3]
 [1 4 4]
 [1 9 3]
 [3 4 7]]

训练结果,与训练数据一一对应:

[1 1 1 0 1 1 0 0 0 1]

测试数据:

[[4 3 0]  #测试数据
 [3 0 4]
 [2 9 5]
 [2 8 5]]

测试结果,与测试数据一一对应:

[[0.48753831 0.51246169] 
 [0.58182694 0.41817306]
 [0.85361393 0.14638607]
 [0.57018655 0.42981345]]

可以看出,有四行两列,每行对应一条预测数据,两列分别对应 对于0、1的预测概率(左边概率大于0.5则为0,反之为1)

我们来看看使用predict方法获得的结果:

test_y = model.predict(test_X)
print(test_y)

输出结果:[1,0,0,0]

所以有的情况下predict_proba还是很有用的,它可以获得对每种可能结果的概率,使用predict则是直接获得唯一的预测结果,所以在使用的时候,应该灵活使用。

补充一个知识点:关于预测结果标签如何与原来标签相对应

predict_proba返回所有标签值可能性概率值,这些值是如何排序的呢?

返回模型中每个类的样本概率,其中类按类self.classes_进行排序。

其中关键的步骤为numpy的unique方法,即通过np.unique(Label)方法,对Label中的所有标签值进行从小到大的去重排序。得到一个从小到大唯一值的排序。这也就对应于predict_proba的行返回结果。

补充知识: python sklearn decision_function、predict_proba、predict

看代码~

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
X = np.array([[-1,-1],[-2,-1],[1,1],[2,1],[-1,1],[-1,2],[1,-1],[1,-2]])
y = np.array([0,0,1,1,2,2,3,3])
# y=np.array([1,1,2,2,3,3,4,4])
# clf = SVC(decision_function_shape="ovr",probability=True)
clf = SVC(probability=True)
clf.fit(X, y)
print(clf.decision_function(X))
'''
对于n分类,会有n个分类器,然后,任意两个分类器都可以算出一个分类界面,这样,用decision_function()时,对于任意一个样例,就会有n*(n-1)/2个值。
任意两个分类器可以算出一个分类界面,然后这个值就是距离分类界面的距离。
我想,这个函数是为了统计画图,对于二分类时最明显,用来统计每个点离超平面有多远,为了在空间中直观的表示数据以及画超平面还有间隔平面等。
decision_function_shape="ovr"时是4个值,为ovo时是6个值。
'''
print(clf.predict(X))
clf.predict_proba(X) #这个是得分,每个分类器的得分,取最大得分对应的类。
#画图
plot_step=0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
           np.arange(y_min, y_max, plot_step))
 
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #对坐标风格上的点进行预测,来画分界面。其实最终看到的类的分界线就是分界面的边界线。
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")
 
class_names="ABCD"
plot_colors="rybg"
for i, n, c in zip(range(4), class_names, plot_colors):
  idx = np.where(y == i) #i为0或者1,两个类
  plt.scatter(X[idx, 0], X[idx, 1],
        c=c, cmap=plt.cm.Paired,
        label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
plt.show()

sklearn的predict_proba使用说明

以上这篇sklearn的predict_proba使用说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的条件判断语句与循环语句用法小结
Mar 21 Python
简单谈谈Python中的元祖(Tuple)和字典(Dict)
Apr 21 Python
使用Python实现windows下的抓包与解析
Jan 15 Python
对python中的logger模块全面讲解
Apr 28 Python
pandas表连接 索引上的合并方法
Jun 08 Python
python 从csv读数据到mysql的实例
Jun 21 Python
基于python历史天气采集的分析
Feb 14 Python
Python面向对象之类和实例用法分析
Jun 08 Python
Python帮你微信头像任意添加装饰别再@微信官方了
Sep 25 Python
django的403/404/500错误自定义页面的配置方式
May 21 Python
python 元组的使用方法
Jun 09 Python
Jupyter安装链接aconda实现过程图解
Nov 02 Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
You might like
php一次性删除前台checkbox多选内容的方法
2013/09/22 PHP
PHP static局部静态变量和全局静态变量总结
2014/03/02 PHP
深入浅出php socket编程
2015/05/13 PHP
php脚本守护进程原理与实现方法详解
2017/07/20 PHP
PHP设计模式之状态模式定义与用法详解
2018/04/02 PHP
关于ThinkPHP中的异常处理详解
2018/05/11 PHP
清华大学出版的事半功倍系列 javascript全部源代码
2007/05/04 Javascript
一个简单的JavaScript 日期计算算法
2009/09/11 Javascript
javascript void(0)的妙用
2009/10/21 Javascript
jquery 最简单易用的表单验证插件
2010/02/27 Javascript
jquery 触发a链接点击事件解决方案
2013/05/02 Javascript
jQuery中的$.ajax()方法应用
2014/05/06 Javascript
js实现动画特效的文字链接鼠标悬停提示的方法
2015/03/02 Javascript
基于jquery实现表格无刷新分页
2016/01/07 Javascript
详解JavaScript表单验证(E-mail 验证)
2016/03/31 Javascript
Jquery和BigFileUpload实现大文件上传及进度条显示
2016/06/27 Javascript
jQuery插件jqGrid动态获取列和列字段的方法
2017/03/03 Javascript
获取url中用&隔开的参数实例(分享)
2017/05/28 Javascript
ionic2自定义cordova插件开发以及使用(Android)
2017/06/19 Javascript
react在安卓中输入框被手机键盘遮挡问题的解决方法
2018/09/03 Javascript
微信小程序上传文件到阿里OSS教程
2019/05/20 Javascript
手把手带你搭建一个node cli的方法示例
2020/08/07 Javascript
基于vuex实现购物车功能
2021/01/10 Vue.js
Python分析学校四六级过关情况
2017/11/22 Python
python 数据生成excel导出(xlwt,wlsxwrite)代码实例
2019/08/23 Python
Scrapy 配置动态代理IP的实现
2020/09/28 Python
eBay德国站:eBay.de
2017/09/14 全球购物
部队学习十八大感言
2014/01/11 职场文书
生物制药专业自我鉴定
2014/02/19 职场文书
人力资源职位说明书
2014/07/29 职场文书
群众路线个人对照检查材料2014
2014/09/26 职场文书
商场圣诞节活动总结
2015/05/06 职场文书
教师岗位说明书
2015/09/30 职场文书
python实现图片九宫格分割的示例
2021/04/25 Python
低版本Druid连接池+MySQL驱动8.0导致线程阻塞、性能受限
2021/07/01 MySQL
详解Python+OpenCV绘制灰度直方图
2022/03/22 Python