带你学习Python如何实现回归树模型


Posted in Python onJuly 16, 2020

所谓的回归树模型其实就是用树形模型来解决回归问题,树模型当中最经典的自然还是决策树模型,它也是几乎所有树模型的基础。虽然基本结构都是使用决策树,但是根据预测方法的不同也可以分为两种。第一种,树上的叶子节点就对应一个预测值和分类树对应,这一种方法称为回归树。第二种,树上的叶子节点对应一个线性模型,最后的结果由线性模型给出。这一种方法称为模型树。

今天我们先来看看其中的回归树。

回归树模型

CART算法的核心精髓就是我们每次选择特征对数据进行拆分的时候,永远对数据集进行二分。无论是离散特征还是连续性特征,一视同仁。CART还有一个特点是使用GINI指数而不是信息增益或者是信息增益比来选择拆分的特征,但是在回归问题当中用不到这个。因为回归问题的损失函数是均方差,而不是交叉熵,很难用熵来衡量连续值的准确度。

在分类树当中,我们一个叶子节点代表一个类别的预测值,这个类别的值是落到这个叶子节点当中训练样本的类别的众数,也就是出现频率最高的类别。在回归树当中,叶子节点对应的自然就是一个连续值。这个连续值是落到这个节点的训练样本的均值,它的误差就是这些样本的均方差。

另外,之前我们在选择特征的划分阈值的时候,对阈值的选择进行了优化,只选择了那些会引起预测类别变化的阈值。但是在回归问题当中,由于预测值是一个浮点数,所以这个优化也不存在了。整体上来说,其实回归树的实现难度比分类树是更低的。

实战

我们首先来加载数据,我们这次使用的是scikit-learn库当中经典的波士顿房价预测的数据。关于房价预测,kaggle当中也有一个类似的比赛,叫做:house-prices-advanced-regression-techniques。不过给出的特征更多,并且存在缺失等情况,需要我们进行大量的特征工程。感兴趣的同学可以自行研究一下。

首先,我们来获取数据,由于sklearn库当中已经有数据了,我们可以直接调用api获取,非常简单:

import numpy as np
import pandas as pd
from sklearn.datasets import load_boston
boston = load_boston()

X, y = boston.data, boston.target

我们输出前几条数据查看一下:

带你学习Python如何实现回归树模型

这个数据质量很高,sklearn库已经替我们做完了数据筛选与特征工程,直接拿来用就可以了。为了方便我们传递数据,我们将X和y合并在一起。由于y是一维的数组形式是不能和二维的X合并的,所以我们需要先对y进行reshape之后再进行合并。

y = y.reshape(-1, 1)
X = np.hstack((X, y))

hstack函数可以将两个np的array横向拼接,与之对应的是vstack,是将两个array纵向拼接,这个也是常规操作。合并之后,y作为新的一列添加在了X的后面。数据搞定了,接下来就要轮到实现模型了。

在实现决策树的主体部分之前,我们先来实现两个辅助函数。第一个辅助函数是计算一批样本的方差和,第二个辅助函数是获取样本的均值,也就是子节点的预测值。

def node_mean(X):
 return np.mean(X[:, -1])


def node_variance(X):
 return np.var(X[:, -1]) * X.shape[0]

这个搞定了之后,我们继续实现根据阈值拆分数据的函数。这个也可以复用之前的代码:

from collections import defaultdict
def split_dataset(X, idx, thred):
 split_data = defaultdict(list)
 for x in X:
  split_data[x[idx] < thred].append(x)
 return list(split_data.values()), list(split_data.keys())

接下来是两个很重要的函数,分别是get_thresholds和split_variance。顾名思义,第一个函数用来获取阈值,前面说了由于我们做的是回归模型,所以理论上来说特征的每一个取值都可以作为切分的依据。但是也不排除可能会存在多条数据的特征值相同的情况,所以我们对它进行去重。第二个函数是根据阈值对数据进行拆分,返回拆分之后的方差和。

def get_thresholds(X, i):
 return set(X[:, i].tolist())

# 每次迭代方差优化的底线
MINIMUM_IMPROVE = 2.0
# 每个叶子节点最少样本数
MINIMUM_SAMPLES = 10

def split_variance(dataset, idx, threshold):
 left, right = [], []
 n = dataset.shape[0]
 for data in dataset:
  if data[idx] < threshold:
   left.append(data)
  else:
   right.append(data)
 left, right = np.array(left), np.array(right)
 # 预剪枝
 # 如果拆分结果有一边过少,则返回None,防止过拟合
 if len(left) < MINIMUM_SAMPLES or len(right) < MINIMUM_SAMPLES:
  return None
 # 拆分之后的方差和等于左子树的方差和加上右子树的方差和
 # 因为是方差和而不是均方差,所以可以累加
 return node_variance(left) + node_variance(right)

这里我们用到了MINIMUM_SAMPLES这个参数,它是用来预剪枝用的。由于我们是回归模型,如果不对决策树的生长加以限制,那么很有可能得到的决策树的叶子节点和训练样本的数量一样多。这显然就陷入了过拟合了,对于模型的效果是有害无益的。所以我们要限制每个节点的样本数量,这个是一个参数,我们可以根据需要自行调整。

接下来,就是特征和阈值筛选的函数了。我们需要开发一个函数来遍历所有可以拆分的特征和阈值,对数据进行拆分,从所有特征当中找到最佳的拆分可能。

def choose_feature_to_split(dataset):
 n = len(dataset[0])-1
 m = len(dataset)
 # 记录最佳方差,特征和阈值
 var_ = node_variance(dataset)
 bestVar = float('inf')
 feature = -1
 thred = None
 for i in range(n):
  threds = get_thresholds(dataset, i)
  for t in threds:
   # 遍历所有的阈值,计算每个阈值的variance
   v = split_variance(dataset, i, t)
   # 如果v等于None,说明拆分过拟合了,跳过
   if v is None:
    continue
   if v < bestVar:
    bestVar, feature, thred = v, i, t
 # 如果最好的拆分效果达不到要求,那么就不拆分,控制子树的数量
 if var_ - bestVar < MINIMUM_IMPROVE:
  return None, None
 return feature, thred

和上面一样,这个函数当中也用到了一个预剪枝的参数MINIMUM_IMPROVE,它衡量的是每一次生成子树带来的收益。当某一次生成子树带来的收益小于某个值的时候,说明收益很小,并不划算,所以我们就放弃这次子树的生成。这也是预剪枝的一种。

这些都搞定了之后,就可以来建树了。建树的过程和之前类似,只是我们这一次的数据当中没有特征的name,所以我们去掉特征名称的相关逻辑。

def create_decision_tree(dataset):
 dataset = np.array(dataset)
 
 # 如果当前数量小于10,那么就不再继续划分了
 if dataset.shape[0] < MINIMUM_SAMPLES:
  return node_mean(dataset)
 
 # 记录最佳拆分的特征和阈值
 fidx, th = choose_feature_to_split(dataset)
 if fidx is None:
  return th
 
 node = {}
 node['feature'] = fidx
 node['threshold'] = th
 
 # 递归建树
 split_data, vals = split_dataset(dataset, fidx, th)
 for data, val in zip(split_data, vals):
  node[val] = create_decision_tree(data)
 return node

我们来完整测试一下建树,首先我们需要对原始数据进行拆分。将原始数据拆分成训练数据和测试数据,由于我们的场景比较简单,就不设置验证数据了。

拆分数据不用我们自己实现,sklearn当中提供了相应的工具,我们直接调用即可:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=23)

我们一般用到的参数就两个,一个是test_size,它可以是一个整数也可以是一个浮点数。如果是整数,代表的是测试集的样本数量。如果是一个0-1.0的浮点数,则代表测试集的占比。random_state是生成随机数的时候用到的随机种子。

带你学习Python如何实现回归树模型

我们输出一下生成的树,由于数据量比较大,可以看到一颗庞大的树结构。建树的部分实现了之后,最后剩下的就是预测的部分了。

预测部分的代码和之前分类树相差不大,整体的逻辑完全一样,只是去掉了feature_names的相关逻辑。

def classify(node, data):
 key = node['feature']
 pred = None
 thred = node['threshold']

 if isinstance(node[data[key] < thred], dict):
  pred = classify(node[data[key] < thred], data)
 else:
  pred = node[data[key] < thred]
   
 # 放置pred为空,挑选一个叶子节点作为替补
 if pred is None:
  for key in node:
   if not isinstance(node[key], dict):
    pred = node[key]
    break
 return pred

由于这个函数一次只能接受一条数据,如果我们想要批量预测的话还不行,所以最好的话再实现一个批量预测的predict函数比较好。

def predict(node, X):
 y_pred = []
 for x in X:
  y = classify(node, x)
  y_pred.append(y)
 return np.array(y_pred)

后剪枝

后剪枝的英文原文是post-prune,但是翻译成事后剪枝也有点奇怪。anyway,我们就用后剪枝这个词好了。

在回归树当中,我们利用的思想非常朴素,在建树的时候建立一棵尽量复杂庞大的树。然后在通过测试集对这棵树进行修剪,修剪的逻辑也非常简单,我们判断一棵子树存在分叉和没有分叉单独成为叶子节点时的误差,如果修剪之后误差更小,那么我们就减去这棵子树。

整个剪枝的过程和建树的过程一样,从上到下,递归执行。

整个逻辑很好理解,我们直接来看代码:

def is_dict(node):
 return isinstance(node, dict)


def prune(node, testData):
 testData = np.array(testData)
 if testData.shape[0] == 0:
  return node
 
 # 拆分数据
 split_data, _ = split_dataset(testData, node['feature'], node['threshold'])
 # 对左右子树递归修剪
 if is_dict(node[0]):
  node[0] = prune(node[0], split_data[0])
 if is_dict(node[1]) and len(split_data) > 1:
  node[1] = prune(node[1], split_data[1])

 # 如果左右都是叶子节点,那么判断当前子树是否需要修剪
 if len(split_data) > 1 and not is_dict(node[0]) and not is_dict(node[1]):
  # 计算修剪前的方差和
  baseError = np.sum(np.power(np.array(split_data[0])[:, -1] - node[0], 2)) + np.sum(np.power(np.array(split_data[1])[:, -1] - node[1], 2))
  # 计算修剪后的方差和
  meanVal = (node[0] + node[1]) / 2
  mergeError = np.sum(np.power(meanVal - testData[:, -1], 2))
  if mergeError < baseError:
   return meanVal
  else:
   return node
 return node

最后,我们对修剪之后的效果做一下验证:

带你学习Python如何实现回归树模型

从图中可以看到,修剪之前我们在测试数据上的均方差是19.65,而修剪之后降低到了19.48。从数值上来看是有效果的,只是由于我们的训练数据比较少,同时进行了预剪枝,影响了后剪枝的效果。但是对于实际的机器学习工程来说,一个方法只要是有明确效果的,在代价可以承受的范围内,它就是有价值的,千万不能觉得提升不明显,而随便否定一个方法。

这里计算均方差的时候用到了sklearn当中的一个库函数mean_square_error,从名字当中我们也可以看得出来它的用途,它可以对两个Numpy的array计算均方差

总结

关于回归树模型的相关内容到这里就结束了,我们不仅亲手实现了模型,而且还在真实的数据集上做了实验。如果你是亲手实现的模型的代码,相信你一定会有很多收获。

虽然从实际运用来说我们几乎不会使用树模型来做回归任务,但是回归树模型本身是非常有意义的。因为在它的基础上我们发展出了很多效果更好的模型,比如大名鼎鼎的GBDT。因此理解回归树对于我们后续进阶的学习是非常重要的。在深度学习普及之前,其实大多数高效果的模型都是以树模型为基础的,比如随机森林、GBDT、Adaboost等等。可以说树模型撑起了机器学习的半个时代,这么说相信大家应该都能理解它的重要性了吧。

今天的文章就到这里,如果喜欢本文,可以的话,请点个关注,给我一点鼓励,也方便获取更多文章。

以上就是带你学习Python如何实现回归树模型的详细内容,更多关于Python实现回归树模型的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python实现多线程采集的2个代码例子
Jul 07 Python
python模拟事件触发机制详解
Jan 19 Python
Python递归实现汉诺塔算法示例
Mar 19 Python
python实现周期方波信号频谱图
Jul 21 Python
Python基础之条件控制操作示例【if语句】
Mar 23 Python
python下的opencv画矩形和文字注释的实现方法
Jul 09 Python
window环境pip切换国内源(pip安装异常缓慢的问题)
Dec 31 Python
Tensorflow训练MNIST手写数字识别模型
Feb 13 Python
python递归调用中的坑:打印有值, 返回却None
Mar 16 Python
python如何删除文件、目录
Jun 23 Python
解决python运行效率不高的问题
Jul 20 Python
python中__slots__节约内存的具体做法
Jul 04 Python
MATLAB数学建模之画图汇总
Jul 16 #Python
浅析Python迭代器的高级用法
Jul 16 #Python
python 使用递归的方式实现语义图片分割功能
Jul 16 #Python
Django serializer优化类视图的实现示例
Jul 16 #Python
python中plt.imshow与cv2.imshow显示颜色问题
Jul 16 #Python
Python实现GIF图倒放
Jul 16 #Python
浅谈python处理json和redis hash的坑
Jul 16 #Python
You might like
一个典型的PHP分页实例代码分享
2011/07/28 PHP
php判断linux下程序问题实例
2015/07/09 PHP
php自定义函数br2nl实现将html中br换行符转换为文本输入中换行符的方法【与函数nl2br功能相反】
2017/02/17 PHP
建立良好体验度的Web注册系统ajax
2007/07/09 Javascript
JQuery datepicker 使用方法
2011/05/20 Javascript
Javascript 按位左移运算符使用介绍(
2014/02/04 Javascript
深入分析js的冒泡事件
2014/12/05 Javascript
jquery简单的弹出层浮动层代码
2015/04/27 Javascript
JavaScript中实现无缝滚动、分享到侧边栏实例代码
2016/04/06 Javascript
JavaScript:Date类型全面解析
2016/05/19 Javascript
原生js实现弹出层效果
2017/01/20 Javascript
实现div内部滚动条滚动到底部和顶部的代码
2017/11/15 Javascript
对vue.js中this.$emit的深入理解
2018/02/23 Javascript
js replace 全局替换的操作方法
2018/06/12 Javascript
给localStorage设置一个过期时间的方法分享
2018/11/06 Javascript
vue2之简易的pc端短信验证码的问题及处理方法
2019/06/03 Javascript
Vue 中获取当前时间并实时刷新的实现代码
2020/05/12 Javascript
如何在postman中添加cookie信息步骤解析
2020/06/30 Javascript
基于Vant UI框架实现时间段选择器
2020/12/24 Javascript
[03:47]2015国际邀请赛第三日现场精彩回顾
2015/08/08 DOTA
python SSH模块登录,远程机执行shell命令实例解析
2018/01/12 Python
详解Python中的动态属性和特性
2018/04/07 Python
Window环境下Scrapy开发环境搭建
2018/11/18 Python
解决python执行不输出系统命令弹框的问题
2019/06/24 Python
python实现ftp文件传输系统(案例分析)
2020/03/20 Python
Python根据字符串调用函数过程解析
2020/11/05 Python
html5 input属性使用示例
2013/06/28 HTML / CSS
德国电子商城:ComputerUniverse
2017/04/21 全球购物
Ajax主要包含了哪些技术
2014/06/12 面试题
环卫工作汇报材料
2014/10/28 职场文书
整改落实情况汇报材料
2014/10/29 职场文书
2015入党自传格式范文
2015/06/26 职场文书
如何使用Maxwell实时同步mysql数据
2021/04/08 MySQL
Python+Appium自动化测试的实战
2021/06/30 Python
Python数据结构之队列详解
2022/03/21 Python
使用Postman测试需要授权的接口问题
2022/06/21 Java/Android