sklearn的predict_proba使用说明


Posted in Python onJune 28, 2020

发现个很有用的方法——predict_proba

今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到。

我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率。(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0、1就有两个概率)

举例:

获取数据及预测代码:

from sklearn.linear_model import LogisticRegression
import numpy as np
 
train_X = np.array(np.random.randint(0,10,size=30).reshape(10,3))
train_y = np.array(np.random.randint(0,2,size=10))
test_X = np.array(np.random.randint(0,10,size=12).reshape(4,3))
 
model = LogisticRegression()
model.fit(train_X,train_y)
test_y = model.predict_proba(test_X)
 
print(train_X)
print(train_y)
print(test_y)

训练数据

[[2 9 8]
 [0 8 5]
 [7 1 2]
 [8 4 6]
 [8 8 3]
 [7 2 7]
 [6 4 3]
 [1 4 4]
 [1 9 3]
 [3 4 7]]

训练结果,与训练数据一一对应:

[1 1 1 0 1 1 0 0 0 1]

测试数据:

[[4 3 0]  #测试数据
 [3 0 4]
 [2 9 5]
 [2 8 5]]

测试结果,与测试数据一一对应:

[[0.48753831 0.51246169] 
 [0.58182694 0.41817306]
 [0.85361393 0.14638607]
 [0.57018655 0.42981345]]

可以看出,有四行两列,每行对应一条预测数据,两列分别对应 对于0、1的预测概率(左边概率大于0.5则为0,反之为1)

我们来看看使用predict方法获得的结果:

test_y = model.predict(test_X)
print(test_y)

输出结果:[1,0,0,0]

所以有的情况下predict_proba还是很有用的,它可以获得对每种可能结果的概率,使用predict则是直接获得唯一的预测结果,所以在使用的时候,应该灵活使用。

补充一个知识点:关于预测结果标签如何与原来标签相对应

predict_proba返回所有标签值可能性概率值,这些值是如何排序的呢?

返回模型中每个类的样本概率,其中类按类self.classes_进行排序。

其中关键的步骤为numpy的unique方法,即通过np.unique(Label)方法,对Label中的所有标签值进行从小到大的去重排序。得到一个从小到大唯一值的排序。这也就对应于predict_proba的行返回结果。

补充知识: python sklearn decision_function、predict_proba、predict

看代码~

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
X = np.array([[-1,-1],[-2,-1],[1,1],[2,1],[-1,1],[-1,2],[1,-1],[1,-2]])
y = np.array([0,0,1,1,2,2,3,3])
# y=np.array([1,1,2,2,3,3,4,4])
# clf = SVC(decision_function_shape="ovr",probability=True)
clf = SVC(probability=True)
clf.fit(X, y)
print(clf.decision_function(X))
'''
对于n分类,会有n个分类器,然后,任意两个分类器都可以算出一个分类界面,这样,用decision_function()时,对于任意一个样例,就会有n*(n-1)/2个值。
任意两个分类器可以算出一个分类界面,然后这个值就是距离分类界面的距离。
我想,这个函数是为了统计画图,对于二分类时最明显,用来统计每个点离超平面有多远,为了在空间中直观的表示数据以及画超平面还有间隔平面等。
decision_function_shape="ovr"时是4个值,为ovo时是6个值。
'''
print(clf.predict(X))
clf.predict_proba(X) #这个是得分,每个分类器的得分,取最大得分对应的类。
#画图
plot_step=0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
           np.arange(y_min, y_max, plot_step))
 
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #对坐标风格上的点进行预测,来画分界面。其实最终看到的类的分界线就是分界面的边界线。
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")
 
class_names="ABCD"
plot_colors="rybg"
for i, n, c in zip(range(4), class_names, plot_colors):
  idx = np.where(y == i) #i为0或者1,两个类
  plt.scatter(X[idx, 0], X[idx, 1],
        c=c, cmap=plt.cm.Paired,
        label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
plt.show()

sklearn的predict_proba使用说明

以上这篇sklearn的predict_proba使用说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程之属性和方法实例详解
May 19 Python
在Django框架中编写Contact表单的教程
Jul 17 Python
python 分离文件名和路径以及分离文件名和后缀的方法
Oct 21 Python
python调用百度地图WEB服务API获取地点对应坐标值
Jan 16 Python
超简单使用Python换脸实例
Mar 27 Python
django的聚合函数和aggregate、annotate方法使用详解
Jul 23 Python
python Dijkstra算法实现最短路径问题的方法
Sep 19 Python
python 字段拆分详解
Dec 17 Python
Python使用ElementTree美化XML格式的操作
Mar 06 Python
django rest framework serializer返回时间自动格式化方法
Mar 31 Python
QT5 Designer 打不开的问题及解决方法
Aug 20 Python
Python实现DBSCAN聚类算法并样例测试
Jun 22 Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
You might like
sql注入与转义的php函数代码
2013/06/17 PHP
ThinkPHP模板引擎之导入资源文件方法详解
2014/06/18 PHP
用roll.js实现的图片自动滚动+鼠标触动的特效
2007/03/18 Javascript
jquery 应用代码 方便的排序功能
2010/02/06 Javascript
基于jquery的lazy loader插件实现图片的延迟加载[简单使用]
2011/05/07 Javascript
JS 屏蔽按键效果与改变按键效果的示例代码
2013/12/24 Javascript
Javascript:为input设置readOnly属性(示例讲解)
2013/12/25 Javascript
JS将数字转换成三位逗号分隔的样式(示例代码)
2014/02/19 Javascript
IE6中链接A的href为javascript协议时不在当前页面跳转
2014/06/05 Javascript
jQuery的one()方法用法实例
2015/01/19 Javascript
javascript实现验证身份证号的有效性并提示
2015/04/30 Javascript
基于jQuery实现的菜单切换效果
2015/10/16 Javascript
jQuery实现无限往下滚动效果代码
2016/04/16 Javascript
react.js 获取真实的DOM节点实例(必看)
2017/04/17 Javascript
使用Vue组件实现一个简单弹窗效果
2018/04/23 Javascript
JS文件中加载jquery.js的实例代码
2018/05/05 jQuery
vue 中swiper的使用教程
2018/05/22 Javascript
Javascript实现异步编程的过程
2018/06/18 Javascript
详解如何配置vue-cli3.0的vue.config.js
2018/08/23 Javascript
微信小程序实现banner图轮播效果
2020/06/28 Javascript
vue动态路由:路由参数改变,视图不更新问题的解决
2019/11/05 Javascript
[03:24]CDEC.Y赛前采访 努力备战2016国际邀请赛中国区预选赛
2016/06/25 DOTA
深入解读Python解析XML的几种方式
2016/02/16 Python
基于pandas数据样本行列选取的方法
2018/04/20 Python
Flask框架学习笔记之使用Flask实现表单开发详解
2019/08/12 Python
python 利用turtle模块画出没有角的方格
2019/11/23 Python
python输出数学符号实例
2020/05/11 Python
python3通过subprocess模块调用脚本并和脚本交互的操作
2020/12/05 Python
Python图像处理之膨胀与腐蚀的操作
2021/02/07 Python
HTML5如何使用SVG的方法示例
2019/01/11 HTML / CSS
高二美术教学反思
2014/01/14 职场文书
机械工程师岗位职责
2014/06/16 职场文书
2015年食堂工作总结报告
2015/04/23 职场文书
文明礼仪倡议书
2015/04/28 职场文书
公司老总年会致辞
2015/07/30 职场文书
2017春节晚会开幕词
2016/03/03 职场文书