sklearn的predict_proba使用说明


Posted in Python onJune 28, 2020

发现个很有用的方法——predict_proba

今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到。

我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率。(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0、1就有两个概率)

举例:

获取数据及预测代码:

from sklearn.linear_model import LogisticRegression
import numpy as np
 
train_X = np.array(np.random.randint(0,10,size=30).reshape(10,3))
train_y = np.array(np.random.randint(0,2,size=10))
test_X = np.array(np.random.randint(0,10,size=12).reshape(4,3))
 
model = LogisticRegression()
model.fit(train_X,train_y)
test_y = model.predict_proba(test_X)
 
print(train_X)
print(train_y)
print(test_y)

训练数据

[[2 9 8]
 [0 8 5]
 [7 1 2]
 [8 4 6]
 [8 8 3]
 [7 2 7]
 [6 4 3]
 [1 4 4]
 [1 9 3]
 [3 4 7]]

训练结果,与训练数据一一对应:

[1 1 1 0 1 1 0 0 0 1]

测试数据:

[[4 3 0]  #测试数据
 [3 0 4]
 [2 9 5]
 [2 8 5]]

测试结果,与测试数据一一对应:

[[0.48753831 0.51246169] 
 [0.58182694 0.41817306]
 [0.85361393 0.14638607]
 [0.57018655 0.42981345]]

可以看出,有四行两列,每行对应一条预测数据,两列分别对应 对于0、1的预测概率(左边概率大于0.5则为0,反之为1)

我们来看看使用predict方法获得的结果:

test_y = model.predict(test_X)
print(test_y)

输出结果:[1,0,0,0]

所以有的情况下predict_proba还是很有用的,它可以获得对每种可能结果的概率,使用predict则是直接获得唯一的预测结果,所以在使用的时候,应该灵活使用。

补充一个知识点:关于预测结果标签如何与原来标签相对应

predict_proba返回所有标签值可能性概率值,这些值是如何排序的呢?

返回模型中每个类的样本概率,其中类按类self.classes_进行排序。

其中关键的步骤为numpy的unique方法,即通过np.unique(Label)方法,对Label中的所有标签值进行从小到大的去重排序。得到一个从小到大唯一值的排序。这也就对应于predict_proba的行返回结果。

补充知识: python sklearn decision_function、predict_proba、predict

看代码~

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
X = np.array([[-1,-1],[-2,-1],[1,1],[2,1],[-1,1],[-1,2],[1,-1],[1,-2]])
y = np.array([0,0,1,1,2,2,3,3])
# y=np.array([1,1,2,2,3,3,4,4])
# clf = SVC(decision_function_shape="ovr",probability=True)
clf = SVC(probability=True)
clf.fit(X, y)
print(clf.decision_function(X))
'''
对于n分类,会有n个分类器,然后,任意两个分类器都可以算出一个分类界面,这样,用decision_function()时,对于任意一个样例,就会有n*(n-1)/2个值。
任意两个分类器可以算出一个分类界面,然后这个值就是距离分类界面的距离。
我想,这个函数是为了统计画图,对于二分类时最明显,用来统计每个点离超平面有多远,为了在空间中直观的表示数据以及画超平面还有间隔平面等。
decision_function_shape="ovr"时是4个值,为ovo时是6个值。
'''
print(clf.predict(X))
clf.predict_proba(X) #这个是得分,每个分类器的得分,取最大得分对应的类。
#画图
plot_step=0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
           np.arange(y_min, y_max, plot_step))
 
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #对坐标风格上的点进行预测,来画分界面。其实最终看到的类的分界线就是分界面的边界线。
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")
 
class_names="ABCD"
plot_colors="rybg"
for i, n, c in zip(range(4), class_names, plot_colors):
  idx = np.where(y == i) #i为0或者1,两个类
  plt.scatter(X[idx, 0], X[idx, 1],
        c=c, cmap=plt.cm.Paired,
        label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
plt.show()

sklearn的predict_proba使用说明

以上这篇sklearn的predict_proba使用说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3使用tkinter实现ui界面简单实例
Jan 10 Python
Python获取系统默认字符编码的方法
Jun 04 Python
Python基于Flask框架配置依赖包信息的项目迁移部署
Mar 02 Python
python实现五子棋小游戏
Mar 25 Python
Python中super函数用法实例分析
Mar 18 Python
Python实现的删除重复文件或图片功能示例【去重】
Apr 23 Python
itchat-python搭建微信机器人(附示例)
Jun 11 Python
python实现单链表的方法示例
Sep 03 Python
Win10下python 2.7与python 3.7双环境安装教程图解
Oct 12 Python
Django连接数据库并实现读写分离过程解析
Nov 13 Python
python编写一个会算账的脚本的示例代码
Jun 02 Python
Pycharm-community-2020.2.3 社区版安装教程图文详解
Dec 08 Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
You might like
(PHP实现)只使用++运算实现加法,减法,乘法,除法
2013/06/27 PHP
php之curl设置超时实例
2014/11/03 PHP
PHP中的print_r 与 var_dump 输出数组
2016/06/13 PHP
PHP实现多关键字加亮功能
2016/10/21 PHP
php7安装mongoDB扩展的方法分析
2017/08/02 PHP
鼠标经过显示二级菜单js特效
2013/08/13 Javascript
js获取微信版本号的方法
2015/05/12 Javascript
javascript中的altKey 和 Event属性大全
2015/11/06 Javascript
以jQuery中$.Deferred对象为例讲解promise对象是如何处理异步问题
2015/11/13 Javascript
关于在Servelet中如何获取当前时间的操作方法
2016/06/28 Javascript
AngularJS入门教程之链接与图片模板详解
2016/08/19 Javascript
浅谈js中StringBuffer类的实现方法及使用
2016/09/02 Javascript
JavaScript中数组Array方法详解
2017/02/27 Javascript
JS实现双击内容变为可编辑状态
2017/03/03 Javascript
原生JS封装animate运动框架的实例
2017/10/12 Javascript
jQuery实现的简单无刷新评论功能示例
2017/11/08 jQuery
node中modules.exports与exports导出的区别
2018/06/08 Javascript
JS实现求5的阶乘示例
2019/01/21 Javascript
微信小程序之侧边栏滑动实现过程解析(附完整源码)
2019/08/23 Javascript
JavaScript JSON使用原理及注意事项
2020/07/30 Javascript
Vue 实现对quill-editor组件中的工具栏添加title
2020/08/03 Javascript
vite2.0+vue3移动端项目实战详解
2021/03/03 Vue.js
python MysqlDb模块安装及其使用详解
2018/02/23 Python
python微信跳一跳系列之色块轮廓定位棋盘
2018/02/26 Python
利用Python实现kNN算法的代码
2019/08/16 Python
python3 map函数和filter函数详解
2019/08/26 Python
Pycharm 安装 idea VIM插件的图文教程详解
2020/02/21 Python
python GUI库图形界面开发之PyQt5访问系统剪切板QClipboard类详细使用方法与实例
2020/02/27 Python
AmazeUI折叠式卡片布局,整合内容列表、表格组件实现
2020/08/20 HTML / CSS
C++的几个面试题附答案
2016/08/03 面试题
四川internet信息高速公路(C#)笔试题
2012/02/29 面试题
优秀大学生职业生涯规划书
2014/02/27 职场文书
一份恶作剧的检讨书
2014/09/13 职场文书
通知书大全
2015/04/27 职场文书
Matplotlib可视化之添加让统计图变得简单易懂的注释
2021/06/11 Python
Android中的Launch Mode详情
2022/06/05 Java/Android