sklearn的predict_proba使用说明


Posted in Python onJune 28, 2020

发现个很有用的方法——predict_proba

今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到。

我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率。(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0、1就有两个概率)

举例:

获取数据及预测代码:

from sklearn.linear_model import LogisticRegression
import numpy as np
 
train_X = np.array(np.random.randint(0,10,size=30).reshape(10,3))
train_y = np.array(np.random.randint(0,2,size=10))
test_X = np.array(np.random.randint(0,10,size=12).reshape(4,3))
 
model = LogisticRegression()
model.fit(train_X,train_y)
test_y = model.predict_proba(test_X)
 
print(train_X)
print(train_y)
print(test_y)

训练数据

[[2 9 8]
 [0 8 5]
 [7 1 2]
 [8 4 6]
 [8 8 3]
 [7 2 7]
 [6 4 3]
 [1 4 4]
 [1 9 3]
 [3 4 7]]

训练结果,与训练数据一一对应:

[1 1 1 0 1 1 0 0 0 1]

测试数据:

[[4 3 0]  #测试数据
 [3 0 4]
 [2 9 5]
 [2 8 5]]

测试结果,与测试数据一一对应:

[[0.48753831 0.51246169] 
 [0.58182694 0.41817306]
 [0.85361393 0.14638607]
 [0.57018655 0.42981345]]

可以看出,有四行两列,每行对应一条预测数据,两列分别对应 对于0、1的预测概率(左边概率大于0.5则为0,反之为1)

我们来看看使用predict方法获得的结果:

test_y = model.predict(test_X)
print(test_y)

输出结果:[1,0,0,0]

所以有的情况下predict_proba还是很有用的,它可以获得对每种可能结果的概率,使用predict则是直接获得唯一的预测结果,所以在使用的时候,应该灵活使用。

补充一个知识点:关于预测结果标签如何与原来标签相对应

predict_proba返回所有标签值可能性概率值,这些值是如何排序的呢?

返回模型中每个类的样本概率,其中类按类self.classes_进行排序。

其中关键的步骤为numpy的unique方法,即通过np.unique(Label)方法,对Label中的所有标签值进行从小到大的去重排序。得到一个从小到大唯一值的排序。这也就对应于predict_proba的行返回结果。

补充知识: python sklearn decision_function、predict_proba、predict

看代码~

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
X = np.array([[-1,-1],[-2,-1],[1,1],[2,1],[-1,1],[-1,2],[1,-1],[1,-2]])
y = np.array([0,0,1,1,2,2,3,3])
# y=np.array([1,1,2,2,3,3,4,4])
# clf = SVC(decision_function_shape="ovr",probability=True)
clf = SVC(probability=True)
clf.fit(X, y)
print(clf.decision_function(X))
'''
对于n分类,会有n个分类器,然后,任意两个分类器都可以算出一个分类界面,这样,用decision_function()时,对于任意一个样例,就会有n*(n-1)/2个值。
任意两个分类器可以算出一个分类界面,然后这个值就是距离分类界面的距离。
我想,这个函数是为了统计画图,对于二分类时最明显,用来统计每个点离超平面有多远,为了在空间中直观的表示数据以及画超平面还有间隔平面等。
decision_function_shape="ovr"时是4个值,为ovo时是6个值。
'''
print(clf.predict(X))
clf.predict_proba(X) #这个是得分,每个分类器的得分,取最大得分对应的类。
#画图
plot_step=0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
           np.arange(y_min, y_max, plot_step))
 
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #对坐标风格上的点进行预测,来画分界面。其实最终看到的类的分界线就是分界面的边界线。
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")
 
class_names="ABCD"
plot_colors="rybg"
for i, n, c in zip(range(4), class_names, plot_colors):
  idx = np.where(y == i) #i为0或者1,两个类
  plt.scatter(X[idx, 0], X[idx, 1],
        c=c, cmap=plt.cm.Paired,
        label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
plt.show()

sklearn的predict_proba使用说明

以上这篇sklearn的predict_proba使用说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现检测文件MD5值的方法示例
Apr 11 Python
Python单元测试实例详解
May 25 Python
python实现定时提取实时日志程序
Jun 22 Python
django项目搭建与Session使用详解
Oct 10 Python
Form表单及django的form表单的补充
Jul 25 Python
用OpenCV将视频分解成单帧图片,图片合成视频示例
Dec 10 Python
python3将变量写入SQL语句的实现方式
Mar 02 Python
Python Django搭建网站流程图解
Jun 13 Python
Python爬取微信小程序通用方法代码实例详解
Sep 29 Python
python获取命令行参数实例方法讲解
Nov 02 Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 Python
python获取对象信息的实例详解
Jul 07 Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
You might like
用PHP和ACCESS写聊天室(二)
2006/10/09 PHP
php根据年月获取季度的方法
2014/03/31 PHP
php与flash as3 socket通信传送文件实现代码
2014/08/16 PHP
PHP文件缓存内容保存格式实例分析
2014/08/20 PHP
PHP实现操作redis的封装类完整实例
2015/11/14 PHP
php实现的错误处理封装类实例
2017/06/20 PHP
PHP连接sftp并下载文件的方法教程
2018/08/26 PHP
javascript flash下fromCharCode和charCodeAt方法使用说明
2008/01/12 Javascript
密码强度检测效果实现原理与代码
2013/01/04 Javascript
用jquery实现输入框获取焦点消失文字
2013/04/27 Javascript
jquery获取子节点和父节点的示例代码
2013/09/10 Javascript
老司机带你解读jQuery插件开发流程
2016/05/16 Javascript
AngularJS 模型详细介绍及实例代码
2016/07/27 Javascript
Vue数据驱动模拟实现3
2017/01/11 Javascript
nodejs开发微信小程序实现密码加密
2017/07/11 NodeJs
JS+WCF实现进度条实时监测数据加载量的方法详解
2017/12/19 Javascript
vue自定义指令directive的使用方法
2019/04/07 Javascript
JavaScript命令模式原理与用法实例详解
2020/03/10 Javascript
axios解决高并发的方法:axios.all()与axios.spread()的操作
2020/11/09 Javascript
Python装饰器decorator用法实例
2014/11/10 Python
python中字符串前面加r的作用
2015/06/04 Python
基于Python的文件类型和字符串详解
2017/12/21 Python
selenium3+python3环境搭建教程图解
2018/12/07 Python
详解python中自定义超时异常的几种方法
2019/07/29 Python
Python流程控制 if else实现解析
2019/09/02 Python
Python pandas RFM模型应用实例详解
2019/11/20 Python
Python enumerate函数遍历数据对象组合过程解析
2019/12/11 Python
opencv python图像梯度实例详解
2020/02/04 Python
Django之富文本(获取内容,设置内容方式)
2020/05/21 Python
html5+css3进度条倒计时动画特效代码【推荐】
2016/03/08 HTML / CSS
财务会计应届生求职信
2013/11/24 职场文书
公务员综合考察材料
2014/02/01 职场文书
村干部培训方案
2014/05/02 职场文书
优秀求职信
2014/05/29 职场文书
教师个人教学总结
2015/02/11 职场文书
趣味运动会新闻稿
2015/07/17 职场文书