sklearn的predict_proba使用说明


Posted in Python onJune 28, 2020

发现个很有用的方法——predict_proba

今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到。

我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率。(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0、1就有两个概率)

举例:

获取数据及预测代码:

from sklearn.linear_model import LogisticRegression
import numpy as np
 
train_X = np.array(np.random.randint(0,10,size=30).reshape(10,3))
train_y = np.array(np.random.randint(0,2,size=10))
test_X = np.array(np.random.randint(0,10,size=12).reshape(4,3))
 
model = LogisticRegression()
model.fit(train_X,train_y)
test_y = model.predict_proba(test_X)
 
print(train_X)
print(train_y)
print(test_y)

训练数据

[[2 9 8]
 [0 8 5]
 [7 1 2]
 [8 4 6]
 [8 8 3]
 [7 2 7]
 [6 4 3]
 [1 4 4]
 [1 9 3]
 [3 4 7]]

训练结果,与训练数据一一对应:

[1 1 1 0 1 1 0 0 0 1]

测试数据:

[[4 3 0]  #测试数据
 [3 0 4]
 [2 9 5]
 [2 8 5]]

测试结果,与测试数据一一对应:

[[0.48753831 0.51246169] 
 [0.58182694 0.41817306]
 [0.85361393 0.14638607]
 [0.57018655 0.42981345]]

可以看出,有四行两列,每行对应一条预测数据,两列分别对应 对于0、1的预测概率(左边概率大于0.5则为0,反之为1)

我们来看看使用predict方法获得的结果:

test_y = model.predict(test_X)
print(test_y)

输出结果:[1,0,0,0]

所以有的情况下predict_proba还是很有用的,它可以获得对每种可能结果的概率,使用predict则是直接获得唯一的预测结果,所以在使用的时候,应该灵活使用。

补充一个知识点:关于预测结果标签如何与原来标签相对应

predict_proba返回所有标签值可能性概率值,这些值是如何排序的呢?

返回模型中每个类的样本概率,其中类按类self.classes_进行排序。

其中关键的步骤为numpy的unique方法,即通过np.unique(Label)方法,对Label中的所有标签值进行从小到大的去重排序。得到一个从小到大唯一值的排序。这也就对应于predict_proba的行返回结果。

补充知识: python sklearn decision_function、predict_proba、predict

看代码~

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
X = np.array([[-1,-1],[-2,-1],[1,1],[2,1],[-1,1],[-1,2],[1,-1],[1,-2]])
y = np.array([0,0,1,1,2,2,3,3])
# y=np.array([1,1,2,2,3,3,4,4])
# clf = SVC(decision_function_shape="ovr",probability=True)
clf = SVC(probability=True)
clf.fit(X, y)
print(clf.decision_function(X))
'''
对于n分类,会有n个分类器,然后,任意两个分类器都可以算出一个分类界面,这样,用decision_function()时,对于任意一个样例,就会有n*(n-1)/2个值。
任意两个分类器可以算出一个分类界面,然后这个值就是距离分类界面的距离。
我想,这个函数是为了统计画图,对于二分类时最明显,用来统计每个点离超平面有多远,为了在空间中直观的表示数据以及画超平面还有间隔平面等。
decision_function_shape="ovr"时是4个值,为ovo时是6个值。
'''
print(clf.predict(X))
clf.predict_proba(X) #这个是得分,每个分类器的得分,取最大得分对应的类。
#画图
plot_step=0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
           np.arange(y_min, y_max, plot_step))
 
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #对坐标风格上的点进行预测,来画分界面。其实最终看到的类的分界线就是分界面的边界线。
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")
 
class_names="ABCD"
plot_colors="rybg"
for i, n, c in zip(range(4), class_names, plot_colors):
  idx = np.where(y == i) #i为0或者1,两个类
  plt.scatter(X[idx, 0], X[idx, 1],
        c=c, cmap=plt.cm.Paired,
        label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
plt.show()

sklearn的predict_proba使用说明

以上这篇sklearn的predict_proba使用说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之编写类之四再论继承
Oct 11 Python
零基础写python爬虫之打包生成exe文件
Nov 06 Python
python实现应用程序在右键菜单中添加打开方式功能
Jan 09 Python
用Python将IP地址在整型和字符串之间轻松转换
Mar 22 Python
Python子类继承父类构造函数详解
Feb 19 Python
Centos7 下安装最新的python3.8
Oct 28 Python
Python pickle模块实现对象序列化
Nov 22 Python
Python中base64与xml取值结合问题
Dec 22 Python
keras分类模型中的输入数据与标签的维度实例
Jul 03 Python
django使用graphql的实例
Sep 02 Python
tensorflow2.0教程之Keras快速入门
Feb 20 Python
Python加密技术之RSA加密解密的实现
Apr 08 Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
You might like
PHPMYADMIN 简明安装教程 推荐
2010/03/07 PHP
PHP数组操作类实例
2015/07/11 PHP
实现PHP框架系列文章(6)mysql数据库方法
2016/03/04 PHP
php实现基于pdo的事务处理方法示例
2017/07/21 PHP
PHP简单实现二维数组赋值与遍历功能示例
2017/10/19 PHP
PHP生成zip压缩包的常用方法示例
2019/08/22 PHP
json 实例详细说明教程
2009/10/31 Javascript
DWZ刷新dialog解决方法
2013/03/03 Javascript
JavaScript中setMonth()方法的使用详解
2015/06/11 Javascript
js+css绘制颜色动态变化的圈中圈效果
2016/01/27 Javascript
JavaScript蒙板(model)功能的简单实现代码
2016/08/04 Javascript
Vue的百度地图插件尝试使用
2017/09/06 Javascript
jQuery实现倒计时功能 jQuery实现计时器功能
2017/09/19 jQuery
浅谈node的事件机制
2017/10/09 Javascript
微信小程序实现点击按钮移动view标签的位置功能示例【附demo源码下载】
2017/12/06 Javascript
微信小程序 Animation实现图片旋转动画示例
2018/08/22 Javascript
Vue加载json文件的方法简单示例
2019/01/28 Javascript
基于Proxy的小程序状态管理实现
2019/06/14 Javascript
JS中数据结构与算法---排序算法(Sort Algorithm)实例详解
2019/06/17 Javascript
javascript操作向表格中动态加载数据
2020/08/27 Javascript
[01:01:29]2018DOTA2亚洲邀请赛 4.4 淘汰赛 VP vs Liquid 第一场
2018/04/05 DOTA
Python 文件和输入输出小结
2013/10/09 Python
python实现基于SVM手写数字识别功能
2020/05/27 Python
python 3.3 下载固定链接文件并保存的方法
2018/12/18 Python
100行Python代码实现每天不同时间段定时给女友发消息
2019/09/27 Python
TensorFlow tf.nn.conv2d实现卷积的方式
2020/01/03 Python
Python日期格式和字符串格式相互转换的方法
2020/02/18 Python
世界闻名的衬衫制造商:Savile Row Company
2018/07/30 全球购物
西北政法大学自主招生自荐信
2014/01/29 职场文书
三项教育活动实施方案
2014/03/30 职场文书
幼儿园家长寄语
2014/04/02 职场文书
甜品店创业计划书
2014/08/14 职场文书
优秀党员自我评价范文
2014/09/15 职场文书
分居协议书范本
2014/11/03 职场文书
2019最新校园运动会广播稿!
2019/06/28 职场文书
KVM基础命令详解
2022/04/30 Servers