python机器学习实战之树回归详解


Posted in Python onDecember 20, 2017

本文实例为大家分享了树回归的具体代码,供大家参考,具体内容如下

#-*- coding:utf-8 -*- 
#!/usr/bin/python 
''''' 
回归树  连续值回归预测 的 回归树 
''' 
# 测试代码 
# import regTrees as RT RT.RtTreeTest() RT.RtTreeTest('ex0.txt') RT.RtTreeTest('ex2.txt') 
# import regTrees as RT RT.RtTreeTest('ex2.txt',ops=(10000,4)) 
# import regTrees as RT RT.pruneTest() 
# 模型树 测试 
# import regTrees as RT RT.modeTreeTest(ops=(1,10) 
# 模型回归树和普通回归树 效果比较 计算相关系数  
# import regTrees as RT RT.MRTvsSRT() 
from numpy import * 
 
 
# Tab 键值分隔的数据 提取成 列表数据集 成浮点型数据 
def loadDataSet(fileName):   #   
  dataMat = []        # 目标数据集 列表 
  fr = open(fileName) 
  for line in fr.readlines(): 
    curLine = line.strip().split('\t') 
    fltLine = map(float,curLine) #转换成浮点型数据 
    dataMat.append(fltLine) 
  return dataMat 
 
# 按特征值 的数据集二元切分  特征(列)  对应的值 
# 某一列的值大于value值的一行样本全部放在一个矩阵里,其余放在另一个矩阵里 
def binSplitDataSet(dataSet, feature, value): 
  mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0] # 数组过滤 
  mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0] #  
  return mat0,mat1 
 
# 常量叶子节点 
def regLeaf(dataSet):# 最后一列为标签 为数的叶子节点 
  return mean(dataSet[:,-1])# 目标变量的均值 
# 方差 
def regErr(dataSet): 
  return var(dataSet[:,-1]) * shape(dataSet)[0]# 目标变量的平方误差 * 样本个数(行数)的得到总方差 
 
# 选择最优的 分裂属性和对应的大小 
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): 
  tolS = ops[0] # 允许的误差下降值 
  tolN = ops[1] # 切分的最少样本数量 
  if len(set(dataSet[:,-1].T.tolist()[0])) == 1: # 特征剩余数量为1 则返回 
    return None, leafType(dataSet)       #### 返回 1 ####  
  m,n = shape(dataSet) # 当前数据集大小 形状 
  S = errType(dataSet) # 当前数据集误差 均方误差 
  bestS = inf; bestIndex = 0; bestValue = 0 
  for featIndex in range(n-1):# 遍历 可分裂特征 
    for splitVal in set(dataSet[:,featIndex]):# 遍历对应 特性的 属性值 
      mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)# 进行二元分割 
      if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue #样本数量 小于设定值,则不切分 
      newS = errType(mat0) + errType(mat1)# 二元分割后的 均方差 
      if newS < bestS: # 弱比分裂前小 则保留这个分类 
        bestIndex = featIndex 
        bestValue = splitVal 
        bestS = newS 
  if (S - bestS) < tolS: # 弱分裂后 比 分裂前样本方差 减小的不多 也不进行切分 
    return None, leafType(dataSet)       #### 返回 2 ####  
  mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) 
  if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #样本数量 小于设定值,则不切分 
    return None, leafType(dataSet)       #### 返回 3 ####  
  return bestIndex,bestValue # 返回最佳的 分裂属性 和 对应的值 
 
# 创建回归树 numpy数组数据集 叶子函数  误差函数  用户设置参数(最小样本数量 以及最小误差下降间隔) 
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): 
 # 找到最佳的待切分特征和对应 的值 
  feat, val = chooseBestSplit(dataSet, leafType, errType, ops)# 
 # 停止条件 该节点不能再分,该节点为叶子节点 
  if feat == None: return val  
  retTree = {} 
  retTree['spInd'] = feat #特征 
  retTree['spVal'] = val #值 
 # 执行二元切分  
  lSet, rSet = binSplitDataSet(dataSet, feat, val)# 二元切分 左树 右树 
 # 创建左树 
  retTree['left'] = createTree(lSet, leafType, errType, ops)  # 左树 最终返回子叶子节点 的属性值 
 # 创建右树 
  retTree['right'] = createTree(rSet, leafType, errType, ops) # 右树 
  return retTree  
 
# 未进行后剪枝的回归树测试  
def RtTreeTest(filename='ex00.txt',ops=(1,4)): 
  MyDat = loadDataSet(filename) # ex00.txt y = w*x 两维  ex0.txt y = w*x+b 三维 
  MyMat = mat(MyDat) 
  print createTree(MyMat,ops=ops) 
# 判断是不是树 (按字典形式存储) 
def isTree(obj): 
  return (type(obj).__name__=='dict') 
 
# 返回树的平均值 塌陷处理 
def getMean(tree): 
  if isTree(tree['right']):  
  tree['right'] = getMean(tree['right']) 
  if isTree(tree['left']):  
  tree['left'] = getMean(tree['left']) 
  return (tree['left']+tree['right'])/2.0 # 两个叶子节点的 平均值 
 
# 后剪枝  待剪枝的树  剪枝所需的测试数据 
def prune(tree, testData): 
  if shape(testData)[0] == 0:  
  return getMean(tree) #没有测试数据 返回 
  if (isTree(tree['right']) or isTree(tree['left'])): # 如果回归树的左右两边是树 
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])#对测试数据 进行切分 
  if isTree(tree['left']):  
  tree['left'] = prune(tree['left'], lSet)  # 对左树进行剪枝 
  if isTree(tree['right']):  
  tree['right'] = prune(tree['right'], rSet)# 对右树进行剪枝 
  if not isTree(tree['left']) and not isTree(tree['right']):#两边都是叶子 
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])#对测试数据 进行切分 
    errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\ 
      sum(power(rSet[:,-1] - tree['right'],2)) # 对两边叶子合并前计算 误差  
    treeMean = (tree['left']+tree['right'])/2.0 # 合并后的 叶子 均值 
    errorMerge = sum(power(testData[:,-1] - treeMean,2))# 合并后 的误差 
    if errorMerge < errorNoMerge: # 合并后的误差小于合并前的误差 
      print "merging"      # 说明合并后的树 误差更小 
      return treeMean      # 返回两个叶子 的均值 作为 合并后的叶子节点 
    else: return tree 
  else: return tree 
   
def pruneTest(): 
  MyDat = loadDataSet('ex2.txt')  
  MyMat = mat(MyDat) 
  MyTree = createTree(MyMat,ops=(0,1))  # 为了得到 最大的树 误差设置为0 个数设置为1 即不进行预剪枝 
  MyDatTest = loadDataSet('ex2test.txt') 
  MyMatTest = mat(MyDatTest) 
  print prune(MyTree,MyMatTest) 
 
 
######叶子节点为线性模型的模型树######### 
# 线性模型 
def linearSolve(dataSet):   
  m,n = shape(dataSet) # 数据集大小 
  X = mat(ones((m,n))) # 自变量 
  Y = mat(ones((m,1))) # 目标变量  
  X[:,1:n] = dataSet[:,0:n-1]# 样本数据集合 
  Y = dataSet[:,-1]     # 标签 
  # 线性模型 求解 
  xTx = X.T*X         
  if linalg.det(xTx) == 0.0: 
    raise NameError('行列式值为零,不能计算逆矩阵,可适当增加ops的第二个值') 
  ws = xTx.I * (X.T * Y) 
  return ws,X,Y 
 
# 模型叶子节点 
def modelLeaf(dataSet):  
  ws,X,Y = linearSolve(dataSet) 
  return ws 
 
# 计算模型误差 
def modelErr(dataSet): 
  ws,X,Y = linearSolve(dataSet) 
  yHat = X * ws 
  return sum(power(Y - yHat,2)) 
 
# 模型树测试 
def modeTreeTest(filename='ex2.txt',ops=(1,4)): 
  MyDat = loadDataSet(filename) #  
  MyMat = mat(MyDat) 
  print createTree(MyMat,leafType=modelLeaf, errType=modelErr,ops=ops)#带入线性模型 和相应 的误差计算函数 
 
 
# 模型效果计较 
# 线性叶子节点 预测计算函数 直接返回 树叶子节点 值 
def regTreeEval(model, inDat): 
  return float(model) 
 
def modelTreeEval(model, inDat): 
  n = shape(inDat)[1] 
  X = mat(ones((1,n+1)))# 增加一列 
  X[:,1:n+1]=inDat 
  return float(X*model) # 返回 值乘以 线性回归系数 
 
# 树预测函数 
def treeForeCast(tree, inData, modelEval=regTreeEval): 
  if not isTree(tree):  
  return modelEval(tree, inData) # 返回 叶子节点 预测值 
  if inData[tree['spInd']] > tree['spVal']:   # 左树 
    if isTree(tree['left']):  
    return treeForeCast(tree['left'], inData, modelEval)# 还是树 则递归调用 
    else:  
    return modelEval(tree['left'], inData) # 计算叶子节点的值 并返回 
  else: 
    if isTree(tree['right']):         # 右树 
    return treeForeCast(tree['right'], inData, modelEval) 
    else:  
    return modelEval(tree['right'], inData)# 计算叶子节点的值 并返回 
 
# 得到预测值     
def createForeCast(tree, testData, modelEval=regTreeEval): 
  m=len(testData) 
  yHat = mat(zeros((m,1)))#预测标签 
  for i in range(m): 
    yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval) 
  return yHat 
 
# 常量回归树和线性模型回归树的预测结果比较 
def MRTvsSRT(): 
  TestMat = mat(loadDataSet('bikeSpeedVsIq_test.txt')) 
  TrainMat = mat(loadDataSet('bikeSpeedVsIq_train.txt')) 
# 普通回归树 预测结果 
  # 得到普通回归树树 
  StaTree = createTree(TrainMat, ops=(1,20)) 
  # 得到预测结果 
  StaYHat = createForeCast(StaTree, TestMat[:,0], regTreeEval)# 第一列为 自变量 
  # 预测结果和真实标签的相关系数 
  StaCorr = corrcoef(StaYHat, TestMat[:,1], rowvar=0)[0,1] # NumPy 库函数  
# 模型回归树 预测结果 
  # 得到模型回归树 
  ModeTree = createTree(TrainMat,leafType=modelLeaf, errType=modelErr, ops=(1,20)) 
  # 得到预测结果 
  ModeYHat = createForeCast(ModeTree, TestMat[:,0], modelTreeEval)  
  # 预测结果和真实标签的相关系数 
  ModeCorr = corrcoef(ModeYHat, TestMat[:,1], rowvar=0)[0,1] # NumPy 库函数   
  print "普通回归树 预测结果的相关系数R2: %f" %(StaCorr)                        
  print "模型回归树 预测结果的相关系数R2: %f" %(ModeCorr) 
  if ModeCorr>StaCorr: 
  print "模型回归树效果优于普通回归树" 
  else: 
  print "回归回归树效果优于模型普通树"

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3实现简单可学习的手写体识别(实例讲解)
Oct 21 Python
Pandas 对Dataframe结构排序的实现方法
Apr 10 Python
详解Python3.6安装psutil模块和功能简介
May 30 Python
Python 从列表中取值和取索引的方法
Dec 25 Python
深入了解Python枚举类型的相关知识
Jul 09 Python
django 微信网页授权认证api的步骤详解
Jul 30 Python
Python提取PDF内容的方法(文本、图像、线条等)
Sep 25 Python
numpy ndarray 取出满足特定条件的某些行实例
Dec 05 Python
pytorch下大型数据集(大型图片)的导入方式
Jan 08 Python
python3中确保枚举值代码分析
Dec 02 Python
python 写一个文件分发小程序
Dec 05 Python
Pandas 数据编码的十种方法
Apr 20 Python
使用python 和 lint 删除项目无用资源的方法
Dec 20 #Python
python机器学习实战之K均值聚类
Dec 20 #Python
Python绘制3d螺旋曲线图实例代码
Dec 20 #Python
python机器学习实战之最近邻kNN分类器
Dec 20 #Python
python3.6 +tkinter GUI编程 实现界面化的文本处理工具(推荐)
Dec 20 #Python
浅谈Python实现Apriori算法介绍
Dec 20 #Python
利用Python如何生成hash值示例详解
Dec 20 #Python
You might like
完善CodeIgniter在IDE中代码提示功能的方法
2014/07/19 PHP
PHP如何将log信息写入服务器中的log文件
2015/07/29 PHP
Ext.get() 和 Ext.query()组合使用实现最灵活的取元素方式
2011/09/26 Javascript
js确定对象类型方法
2012/03/30 Javascript
js关闭浏览器窗口及检查浏览器关闭事件
2013/09/03 Javascript
分享10个原生JavaScript技巧
2015/04/20 Javascript
详解AngularJS过滤器的使用
2016/03/11 Javascript
vue-router项目实战总结篇
2018/02/11 Javascript
vue 子组件向父组件传值方法
2018/02/26 Javascript
浅谈Vue下使用百度地图的简易方法
2018/03/23 Javascript
vue用递归组件写树形控件的实例代码
2018/07/19 Javascript
Vue 框架之键盘事件、健值修饰符、双向数据绑定
2018/11/14 Javascript
vue-router懒加载速度缓慢问题及解决方法
2018/11/25 Javascript
详解小程序用户登录状态检查与更新实例
2019/05/15 Javascript
PHP 502bad gateway原因及解决方案
2020/11/13 Javascript
详解python 发送邮件实例代码
2016/12/22 Python
Python with语句上下文管理器两种实现方法分析
2018/02/09 Python
Python机器学习k-近邻算法(K Nearest Neighbor)实例详解
2018/06/25 Python
NumPy 数学函数及代数运算的实现代码
2018/07/18 Python
python GUI库图形界面开发之PyQt5信号与槽的高级使用技巧(自定义信号与槽)详解与实例
2020/03/06 Python
Python Tornado批量上传图片并显示功能
2020/03/26 Python
Python如何使用vars返回对象的属性列表
2020/10/17 Python
用CSS3将你的设计带入下个高度
2009/08/08 HTML / CSS
html5实现完美兼容各大浏览器的播放器
2014/12/26 HTML / CSS
Lungolivigno Fashion官网:高级时装在线购物
2020/10/17 全球购物
项目合作计划书
2014/01/09 职场文书
写给女朋友的检讨书
2014/01/28 职场文书
劳资协议书范本
2014/04/23 职场文书
我们的节日中秋活动方案
2014/08/19 职场文书
党的群众路线教育实践活动查摆问题自查报告
2014/10/10 职场文书
PHP使用非对称加密算法RSA
2021/04/21 PHP
浅析Python中的套接字编程
2021/06/22 Python
Redis RDB技术底层原理详解
2021/09/04 Redis
springboot 多数据源配置不生效遇到的坑及解决
2021/11/17 Java/Android
SQL优化老出错,那是你没弄明白MySQL解释计划用法
2021/11/27 MySQL
Java 超详细讲解十大排序算法面试无忧
2022/04/08 Java/Android