带你学习Python如何实现回归树模型


Posted in Python onJuly 16, 2020

所谓的回归树模型其实就是用树形模型来解决回归问题,树模型当中最经典的自然还是决策树模型,它也是几乎所有树模型的基础。虽然基本结构都是使用决策树,但是根据预测方法的不同也可以分为两种。第一种,树上的叶子节点就对应一个预测值和分类树对应,这一种方法称为回归树。第二种,树上的叶子节点对应一个线性模型,最后的结果由线性模型给出。这一种方法称为模型树。

今天我们先来看看其中的回归树。

回归树模型

CART算法的核心精髓就是我们每次选择特征对数据进行拆分的时候,永远对数据集进行二分。无论是离散特征还是连续性特征,一视同仁。CART还有一个特点是使用GINI指数而不是信息增益或者是信息增益比来选择拆分的特征,但是在回归问题当中用不到这个。因为回归问题的损失函数是均方差,而不是交叉熵,很难用熵来衡量连续值的准确度。

在分类树当中,我们一个叶子节点代表一个类别的预测值,这个类别的值是落到这个叶子节点当中训练样本的类别的众数,也就是出现频率最高的类别。在回归树当中,叶子节点对应的自然就是一个连续值。这个连续值是落到这个节点的训练样本的均值,它的误差就是这些样本的均方差。

另外,之前我们在选择特征的划分阈值的时候,对阈值的选择进行了优化,只选择了那些会引起预测类别变化的阈值。但是在回归问题当中,由于预测值是一个浮点数,所以这个优化也不存在了。整体上来说,其实回归树的实现难度比分类树是更低的。

实战

我们首先来加载数据,我们这次使用的是scikit-learn库当中经典的波士顿房价预测的数据。关于房价预测,kaggle当中也有一个类似的比赛,叫做:house-prices-advanced-regression-techniques。不过给出的特征更多,并且存在缺失等情况,需要我们进行大量的特征工程。感兴趣的同学可以自行研究一下。

首先,我们来获取数据,由于sklearn库当中已经有数据了,我们可以直接调用api获取,非常简单:

import numpy as np
import pandas as pd
from sklearn.datasets import load_boston
boston = load_boston()

X, y = boston.data, boston.target

我们输出前几条数据查看一下:

带你学习Python如何实现回归树模型

这个数据质量很高,sklearn库已经替我们做完了数据筛选与特征工程,直接拿来用就可以了。为了方便我们传递数据,我们将X和y合并在一起。由于y是一维的数组形式是不能和二维的X合并的,所以我们需要先对y进行reshape之后再进行合并。

y = y.reshape(-1, 1)
X = np.hstack((X, y))

hstack函数可以将两个np的array横向拼接,与之对应的是vstack,是将两个array纵向拼接,这个也是常规操作。合并之后,y作为新的一列添加在了X的后面。数据搞定了,接下来就要轮到实现模型了。

在实现决策树的主体部分之前,我们先来实现两个辅助函数。第一个辅助函数是计算一批样本的方差和,第二个辅助函数是获取样本的均值,也就是子节点的预测值。

def node_mean(X):
 return np.mean(X[:, -1])


def node_variance(X):
 return np.var(X[:, -1]) * X.shape[0]

这个搞定了之后,我们继续实现根据阈值拆分数据的函数。这个也可以复用之前的代码:

from collections import defaultdict
def split_dataset(X, idx, thred):
 split_data = defaultdict(list)
 for x in X:
  split_data[x[idx] < thred].append(x)
 return list(split_data.values()), list(split_data.keys())

接下来是两个很重要的函数,分别是get_thresholds和split_variance。顾名思义,第一个函数用来获取阈值,前面说了由于我们做的是回归模型,所以理论上来说特征的每一个取值都可以作为切分的依据。但是也不排除可能会存在多条数据的特征值相同的情况,所以我们对它进行去重。第二个函数是根据阈值对数据进行拆分,返回拆分之后的方差和。

def get_thresholds(X, i):
 return set(X[:, i].tolist())

# 每次迭代方差优化的底线
MINIMUM_IMPROVE = 2.0
# 每个叶子节点最少样本数
MINIMUM_SAMPLES = 10

def split_variance(dataset, idx, threshold):
 left, right = [], []
 n = dataset.shape[0]
 for data in dataset:
  if data[idx] < threshold:
   left.append(data)
  else:
   right.append(data)
 left, right = np.array(left), np.array(right)
 # 预剪枝
 # 如果拆分结果有一边过少,则返回None,防止过拟合
 if len(left) < MINIMUM_SAMPLES or len(right) < MINIMUM_SAMPLES:
  return None
 # 拆分之后的方差和等于左子树的方差和加上右子树的方差和
 # 因为是方差和而不是均方差,所以可以累加
 return node_variance(left) + node_variance(right)

这里我们用到了MINIMUM_SAMPLES这个参数,它是用来预剪枝用的。由于我们是回归模型,如果不对决策树的生长加以限制,那么很有可能得到的决策树的叶子节点和训练样本的数量一样多。这显然就陷入了过拟合了,对于模型的效果是有害无益的。所以我们要限制每个节点的样本数量,这个是一个参数,我们可以根据需要自行调整。

接下来,就是特征和阈值筛选的函数了。我们需要开发一个函数来遍历所有可以拆分的特征和阈值,对数据进行拆分,从所有特征当中找到最佳的拆分可能。

def choose_feature_to_split(dataset):
 n = len(dataset[0])-1
 m = len(dataset)
 # 记录最佳方差,特征和阈值
 var_ = node_variance(dataset)
 bestVar = float('inf')
 feature = -1
 thred = None
 for i in range(n):
  threds = get_thresholds(dataset, i)
  for t in threds:
   # 遍历所有的阈值,计算每个阈值的variance
   v = split_variance(dataset, i, t)
   # 如果v等于None,说明拆分过拟合了,跳过
   if v is None:
    continue
   if v < bestVar:
    bestVar, feature, thred = v, i, t
 # 如果最好的拆分效果达不到要求,那么就不拆分,控制子树的数量
 if var_ - bestVar < MINIMUM_IMPROVE:
  return None, None
 return feature, thred

和上面一样,这个函数当中也用到了一个预剪枝的参数MINIMUM_IMPROVE,它衡量的是每一次生成子树带来的收益。当某一次生成子树带来的收益小于某个值的时候,说明收益很小,并不划算,所以我们就放弃这次子树的生成。这也是预剪枝的一种。

这些都搞定了之后,就可以来建树了。建树的过程和之前类似,只是我们这一次的数据当中没有特征的name,所以我们去掉特征名称的相关逻辑。

def create_decision_tree(dataset):
 dataset = np.array(dataset)
 
 # 如果当前数量小于10,那么就不再继续划分了
 if dataset.shape[0] < MINIMUM_SAMPLES:
  return node_mean(dataset)
 
 # 记录最佳拆分的特征和阈值
 fidx, th = choose_feature_to_split(dataset)
 if fidx is None:
  return th
 
 node = {}
 node['feature'] = fidx
 node['threshold'] = th
 
 # 递归建树
 split_data, vals = split_dataset(dataset, fidx, th)
 for data, val in zip(split_data, vals):
  node[val] = create_decision_tree(data)
 return node

我们来完整测试一下建树,首先我们需要对原始数据进行拆分。将原始数据拆分成训练数据和测试数据,由于我们的场景比较简单,就不设置验证数据了。

拆分数据不用我们自己实现,sklearn当中提供了相应的工具,我们直接调用即可:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=23)

我们一般用到的参数就两个,一个是test_size,它可以是一个整数也可以是一个浮点数。如果是整数,代表的是测试集的样本数量。如果是一个0-1.0的浮点数,则代表测试集的占比。random_state是生成随机数的时候用到的随机种子。

带你学习Python如何实现回归树模型

我们输出一下生成的树,由于数据量比较大,可以看到一颗庞大的树结构。建树的部分实现了之后,最后剩下的就是预测的部分了。

预测部分的代码和之前分类树相差不大,整体的逻辑完全一样,只是去掉了feature_names的相关逻辑。

def classify(node, data):
 key = node['feature']
 pred = None
 thred = node['threshold']

 if isinstance(node[data[key] < thred], dict):
  pred = classify(node[data[key] < thred], data)
 else:
  pred = node[data[key] < thred]
   
 # 放置pred为空,挑选一个叶子节点作为替补
 if pred is None:
  for key in node:
   if not isinstance(node[key], dict):
    pred = node[key]
    break
 return pred

由于这个函数一次只能接受一条数据,如果我们想要批量预测的话还不行,所以最好的话再实现一个批量预测的predict函数比较好。

def predict(node, X):
 y_pred = []
 for x in X:
  y = classify(node, x)
  y_pred.append(y)
 return np.array(y_pred)

后剪枝

后剪枝的英文原文是post-prune,但是翻译成事后剪枝也有点奇怪。anyway,我们就用后剪枝这个词好了。

在回归树当中,我们利用的思想非常朴素,在建树的时候建立一棵尽量复杂庞大的树。然后在通过测试集对这棵树进行修剪,修剪的逻辑也非常简单,我们判断一棵子树存在分叉和没有分叉单独成为叶子节点时的误差,如果修剪之后误差更小,那么我们就减去这棵子树。

整个剪枝的过程和建树的过程一样,从上到下,递归执行。

整个逻辑很好理解,我们直接来看代码:

def is_dict(node):
 return isinstance(node, dict)


def prune(node, testData):
 testData = np.array(testData)
 if testData.shape[0] == 0:
  return node
 
 # 拆分数据
 split_data, _ = split_dataset(testData, node['feature'], node['threshold'])
 # 对左右子树递归修剪
 if is_dict(node[0]):
  node[0] = prune(node[0], split_data[0])
 if is_dict(node[1]) and len(split_data) > 1:
  node[1] = prune(node[1], split_data[1])

 # 如果左右都是叶子节点,那么判断当前子树是否需要修剪
 if len(split_data) > 1 and not is_dict(node[0]) and not is_dict(node[1]):
  # 计算修剪前的方差和
  baseError = np.sum(np.power(np.array(split_data[0])[:, -1] - node[0], 2)) + np.sum(np.power(np.array(split_data[1])[:, -1] - node[1], 2))
  # 计算修剪后的方差和
  meanVal = (node[0] + node[1]) / 2
  mergeError = np.sum(np.power(meanVal - testData[:, -1], 2))
  if mergeError < baseError:
   return meanVal
  else:
   return node
 return node

最后,我们对修剪之后的效果做一下验证:

带你学习Python如何实现回归树模型

从图中可以看到,修剪之前我们在测试数据上的均方差是19.65,而修剪之后降低到了19.48。从数值上来看是有效果的,只是由于我们的训练数据比较少,同时进行了预剪枝,影响了后剪枝的效果。但是对于实际的机器学习工程来说,一个方法只要是有明确效果的,在代价可以承受的范围内,它就是有价值的,千万不能觉得提升不明显,而随便否定一个方法。

这里计算均方差的时候用到了sklearn当中的一个库函数mean_square_error,从名字当中我们也可以看得出来它的用途,它可以对两个Numpy的array计算均方差

总结

关于回归树模型的相关内容到这里就结束了,我们不仅亲手实现了模型,而且还在真实的数据集上做了实验。如果你是亲手实现的模型的代码,相信你一定会有很多收获。

虽然从实际运用来说我们几乎不会使用树模型来做回归任务,但是回归树模型本身是非常有意义的。因为在它的基础上我们发展出了很多效果更好的模型,比如大名鼎鼎的GBDT。因此理解回归树对于我们后续进阶的学习是非常重要的。在深度学习普及之前,其实大多数高效果的模型都是以树模型为基础的,比如随机森林、GBDT、Adaboost等等。可以说树模型撑起了机器学习的半个时代,这么说相信大家应该都能理解它的重要性了吧。

今天的文章就到这里,如果喜欢本文,可以的话,请点个关注,给我一点鼓励,也方便获取更多文章。

以上就是带你学习Python如何实现回归树模型的详细内容,更多关于Python实现回归树模型的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python实现的百度站长自动URL提交小工具
Jun 27 Python
分析Python中设计模式之Decorator装饰器模式的要点
Mar 02 Python
python字符串中的单双引
Feb 16 Python
Python3.6.0+opencv3.3.0人脸检测示例
May 25 Python
Python实现全排列的打印
Aug 18 Python
使用Python做定时任务及时了解互联网动态
May 15 Python
如何用C代码给Python写扩展库(Cython)
May 17 Python
详解python实现数据归一化处理的方式:(0,1)标准化
Jul 17 Python
python中字典按键或键值排序的实现代码
Aug 27 Python
python解析yaml文件过程详解
Aug 30 Python
python中with用法讲解
Feb 07 Python
详解用Python调用百度地图正/逆地理编码API
Jul 02 Python
MATLAB数学建模之画图汇总
Jul 16 #Python
浅析Python迭代器的高级用法
Jul 16 #Python
python 使用递归的方式实现语义图片分割功能
Jul 16 #Python
Django serializer优化类视图的实现示例
Jul 16 #Python
python中plt.imshow与cv2.imshow显示颜色问题
Jul 16 #Python
Python实现GIF图倒放
Jul 16 #Python
浅谈python处理json和redis hash的坑
Jul 16 #Python
You might like
php 数组的指针操作实现代码
2011/02/08 PHP
PHP中echo,print_r与var_dump区别分析
2014/09/29 PHP
php常量详细解析
2015/10/27 PHP
微信红包随机生成算法php版
2016/07/21 PHP
利用PHP如何统计Nginx日志的User Agent数据
2019/03/06 PHP
如何取得中文输入的真实长度?
2006/06/24 Javascript
基于Jquery的仿照flash放大图片效果代码
2011/03/16 Javascript
电子商务网站上的常用的js放大镜效果
2011/12/08 Javascript
jQuery实现密保互斥问题解决方案
2013/08/16 Javascript
jquery 倒计时效果实现秒杀思路
2013/09/11 Javascript
JS保留两位小数 四舍五入函数的小例子
2013/11/20 Javascript
IE9+已经不对document.createElement向下兼容的解决方法
2015/09/14 Javascript
js窗口关闭提示信息(兼容IE和firefox)
2015/10/23 Javascript
跟我学习javascript创建对象(类)的8种方法
2015/11/20 Javascript
SublimeText自带格式化代码功能之reindent
2015/12/27 Javascript
json实现添加、遍历与删除属性的方法
2016/06/17 Javascript
nodejs如何获取时间戳与时间差
2016/08/03 NodeJs
解决AngualrJS页面刷新导致异常显示问题
2017/04/20 Javascript
基于iScroll实现下拉刷新和上滑加载效果
2017/07/18 Javascript
JS实现json对象数组按对象属性排序操作示例
2018/05/18 Javascript
[54:33]2018DOTA2亚洲邀请赛小组赛 A组加赛 Liquid vs Optic
2018/04/03 DOTA
python进阶教程之函数参数的多种传递方法
2014/08/30 Python
使用Python编写简单网络爬虫抓取视频下载资源
2014/11/04 Python
Python中的条件判断语句与循环语句用法小结
2016/03/21 Python
python 读取Linux服务器上的文件方法
2018/12/27 Python
在 Jupyter 中重新导入特定的 Python 文件(场景分析)
2019/10/27 Python
python实现门限回归方式
2020/02/29 Python
浅谈CSS3 box-sizing 属性 有趣的盒模型
2019/04/02 HTML / CSS
英国女性时尚品牌:Apricot
2018/12/04 全球购物
运动会广播稿50字
2014/01/26 职场文书
计算机专业毕业生求职信
2014/04/30 职场文书
技能比武方案
2014/05/21 职场文书
企业总经理助理岗位职责
2014/09/12 职场文书
2015年超市工作总结范文
2015/05/26 职场文书
公司管理制度范本
2015/08/03 职场文书
MySQL 分组查询的优化方法
2021/05/12 MySQL