sklearn的predict_proba使用说明


Posted in Python onJune 28, 2020

发现个很有用的方法——predict_proba

今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到。

我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率。(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0、1就有两个概率)

举例:

获取数据及预测代码:

from sklearn.linear_model import LogisticRegression
import numpy as np
 
train_X = np.array(np.random.randint(0,10,size=30).reshape(10,3))
train_y = np.array(np.random.randint(0,2,size=10))
test_X = np.array(np.random.randint(0,10,size=12).reshape(4,3))
 
model = LogisticRegression()
model.fit(train_X,train_y)
test_y = model.predict_proba(test_X)
 
print(train_X)
print(train_y)
print(test_y)

训练数据

[[2 9 8]
 [0 8 5]
 [7 1 2]
 [8 4 6]
 [8 8 3]
 [7 2 7]
 [6 4 3]
 [1 4 4]
 [1 9 3]
 [3 4 7]]

训练结果,与训练数据一一对应:

[1 1 1 0 1 1 0 0 0 1]

测试数据:

[[4 3 0]  #测试数据
 [3 0 4]
 [2 9 5]
 [2 8 5]]

测试结果,与测试数据一一对应:

[[0.48753831 0.51246169] 
 [0.58182694 0.41817306]
 [0.85361393 0.14638607]
 [0.57018655 0.42981345]]

可以看出,有四行两列,每行对应一条预测数据,两列分别对应 对于0、1的预测概率(左边概率大于0.5则为0,反之为1)

我们来看看使用predict方法获得的结果:

test_y = model.predict(test_X)
print(test_y)

输出结果:[1,0,0,0]

所以有的情况下predict_proba还是很有用的,它可以获得对每种可能结果的概率,使用predict则是直接获得唯一的预测结果,所以在使用的时候,应该灵活使用。

补充一个知识点:关于预测结果标签如何与原来标签相对应

predict_proba返回所有标签值可能性概率值,这些值是如何排序的呢?

返回模型中每个类的样本概率,其中类按类self.classes_进行排序。

其中关键的步骤为numpy的unique方法,即通过np.unique(Label)方法,对Label中的所有标签值进行从小到大的去重排序。得到一个从小到大唯一值的排序。这也就对应于predict_proba的行返回结果。

补充知识: python sklearn decision_function、predict_proba、predict

看代码~

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
X = np.array([[-1,-1],[-2,-1],[1,1],[2,1],[-1,1],[-1,2],[1,-1],[1,-2]])
y = np.array([0,0,1,1,2,2,3,3])
# y=np.array([1,1,2,2,3,3,4,4])
# clf = SVC(decision_function_shape="ovr",probability=True)
clf = SVC(probability=True)
clf.fit(X, y)
print(clf.decision_function(X))
'''
对于n分类,会有n个分类器,然后,任意两个分类器都可以算出一个分类界面,这样,用decision_function()时,对于任意一个样例,就会有n*(n-1)/2个值。
任意两个分类器可以算出一个分类界面,然后这个值就是距离分类界面的距离。
我想,这个函数是为了统计画图,对于二分类时最明显,用来统计每个点离超平面有多远,为了在空间中直观的表示数据以及画超平面还有间隔平面等。
decision_function_shape="ovr"时是4个值,为ovo时是6个值。
'''
print(clf.predict(X))
clf.predict_proba(X) #这个是得分,每个分类器的得分,取最大得分对应的类。
#画图
plot_step=0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
           np.arange(y_min, y_max, plot_step))
 
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #对坐标风格上的点进行预测,来画分界面。其实最终看到的类的分界线就是分界面的边界线。
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")
 
class_names="ABCD"
plot_colors="rybg"
for i, n, c in zip(range(4), class_names, plot_colors):
  idx = np.where(y == i) #i为0或者1,两个类
  plt.scatter(X[idx, 0], X[idx, 1],
        c=c, cmap=plt.cm.Paired,
        label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
plt.show()

sklearn的predict_proba使用说明

以上这篇sklearn的predict_proba使用说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程语言的35个与众不同之处(语言特征和使用技巧)
Jul 07 Python
跟老齐学Python之不要红头文件(1)
Sep 28 Python
python实现根据用户输入从电影网站获取影片信息的方法
Apr 07 Python
TensorFlow安装及jupyter notebook配置方法
Sep 08 Python
Python面向对象程序设计之继承与多继承用法分析
Jul 13 Python
centos6.5安装python3.7.1之后无法使用pip的解决方案
Feb 14 Python
django使用haystack调用Elasticsearch实现索引搜索
Jul 24 Python
Python实现决策树并且使用Graphviz可视化的例子
Aug 09 Python
python二进制读写及特殊码同步实现详解
Oct 11 Python
python 如何去除字符串头尾的多余符号
Nov 19 Python
Python.append()与Python.expand()用法详解
Dec 18 Python
python实现飞机大战项目
Mar 11 Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
You might like
PHP中使用xmlreader读取xml数据示例
2014/12/29 PHP
php获取字符串中各个字符出现次数的方法
2015/02/23 PHP
百度工程师讲PHP函数的实现原理及性能分析(一)
2015/05/13 PHP
PHP页面间传递值和保持值的方法
2016/08/24 PHP
利用php做服务器和web前端的界面进行交互
2016/10/31 PHP
PHP简单实现防止SQL注入的方法
2018/03/13 PHP
PHP使用curl_multi_select解决curl_multi网页假死问题的方法
2018/08/15 PHP
JavaScript 事件对象的实现
2009/07/13 Javascript
JavaScript Object的extend是一个常用的功能
2009/12/02 Javascript
JS异常处理的一个想法(sofish)
2013/03/14 Javascript
JQuery以JSON方式提交数据到服务端示例代码
2014/05/05 Javascript
table行随鼠标移动变色示例
2014/05/07 Javascript
js获取字符串字节数方法小结
2015/06/09 Javascript
初步使用bootstrap快速创建页面
2016/03/03 Javascript
用js读写cookie的简单方法(推荐)
2016/08/08 Javascript
微信小程序 实例开发总结
2017/04/26 Javascript
JavaScript变量类型以及变量作用域详解
2017/08/14 Javascript
VueCli3构建TS项目的方法步骤
2018/11/07 Javascript
element中el-container容器与div布局区分详解
2020/05/13 Javascript
vue中可编辑树状表格的实现代码
2020/10/31 Javascript
JavaScript函数柯里化实现原理及过程
2020/12/02 Javascript
[01:51]2014DOTA2西雅图邀请赛 MVP 外卡赛black场间采访
2014/07/09 DOTA
[01:02:06]LGD vs Mineski Supermajor 胜者组 BO3 第二场 6.5
2018/06/06 DOTA
深入理解python中的select模块
2017/04/23 Python
django 删除数据库表后重新同步的方法
2018/05/27 Python
Flask框架web开发之零基础入门
2018/12/10 Python
PyTorch之图像和Tensor填充的实例
2019/08/18 Python
Pytorch中的variable, tensor与numpy相互转化的方法
2019/10/10 Python
基于python实现微信好友数据分析(简单)
2020/02/16 Python
python使用re模块爬取豆瓣Top250电影
2020/10/20 Python
浅析两列自适应布局的3种思路
2016/05/03 HTML / CSS
英国文胸专家:AmpleBosom.com
2018/02/06 全球购物
爱我中华教学反思
2014/04/28 职场文书
小学英语课教学反思
2016/02/15 职场文书
Oracle数据库中通用的函数实例详解
2022/03/25 Oracle
Redis中key的过期删除策略和内存淘汰机制
2022/04/12 Redis