解读python如何实现决策树算法


Posted in Python onOctober 11, 2018

数据描述

每条数据项储存在列表中,最后一列储存结果

多条数据项形成数据集

data=[[d1,d2,d3...dn,result],
   [d1,d2,d3...dn,result],
        .
        .
   [d1,d2,d3...dn,result]]

决策树数据结构

class DecisionNode:
  '''决策树节点
  '''
   
  def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
    '''初始化决策树节点
     
    args:    
    col -- 按数据集的col列划分数据集
    value -- 以value作为划分col列的参照
    result -- 只有叶子节点有,代表最终划分出的子数据集结果统计信息。{‘结果':结果出现次数}
    rb,fb -- 代表左右子树
    '''
    self.col=col
    self.value=value
    self.results=results
    self.tb=tb
    self.fb=fb

决策树分类的最终结果是将数据项划分出了若干子集,其中每个子集的结果都一样,所以这里采用{‘结果':结果出现次数}的方式表达每个子集

def pideset(rows,column,value):
  '''依据数据集rows的column列的值,判断其与参考值value的关系对数据集进行拆分
    返回两个数据集
  '''
  split_function=None
  #value是数值类型
  if isinstance(value,int) or isinstance(value,float):
    #定义lambda函数当row[column]>=value时返回true
    split_function=lambda row:row[column]>=value
  #value是字符类型
  else:
    #定义lambda函数当row[column]==value时返回true
    split_function=lambda row:row[column]==value
  #将数据集拆分成两个
  set1=[row for row in rows if split_function(row)]
  set2=[row for row in rows if not split_function(row)]
  #返回两个数据集
  return (set1,set2)
 
def uniquecounts(rows):
  '''计算数据集rows中有几种最终结果,计算结果出现次数,返回一个字典
  '''
  results={}
  for row in rows:
    r=row[len(row)-1]
    if r not in results: results[r]=0
    results[r]+=1
  return results
 
def giniimpurity(rows):
  '''返回rows数据集的基尼不纯度
  '''
  total=len(rows)
  counts=uniquecounts(rows)
  imp=0
  for k1 in counts:
    p1=float(counts[k1])/total
    for k2 in counts:
      if k1==k2: continue
      p2=float(counts[k2])/total
      imp+=p1*p2
  return imp
 
def entropy(rows):
  '''返回rows数据集的熵
  '''
  from math import log
  log2=lambda x:log(x)/log(2) 
  results=uniquecounts(rows)
  ent=0.0
  for r in results.keys():
    p=float(results[r])/len(rows)
    ent=ent-p*log2(p)
  return ent
 
def build_tree(rows,scoref=entropy):
  '''构造决策树
  '''
  if len(rows)==0: return DecisionNode()
  current_score=scoref(rows)
 
  # 最佳信息增益
  best_gain=0.0
  #
  best_criteria=None
  #最佳划分
  best_sets=None
 
  column_count=len(rows[0])-1
  #遍历数据集的列,确定分割顺序
  for col in range(0,column_count):
    column_values={}
    # 构造字典
    for row in rows:
      column_values[row[col]]=1
    for value in column_values.keys():
      (set1,set2)=pideset(rows,col,value)
      p=float(len(set1))/len(rows)
      # 计算信息增益
      gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
      if gain>best_gain and len(set1)>0 and len(set2)>0:
        best_gain=gain
        best_criteria=(col,value)
        best_sets=(set1,set2)
  # 如果划分的两个数据集熵小于原数据集,进一步划分它们
  if best_gain>0:
    trueBranch=build_tree(best_sets[0])
    falseBranch=build_tree(best_sets[1])
    return DecisionNode(col=best_criteria[0],value=best_criteria[1],
            tb=trueBranch,fb=falseBranch)
  # 如果划分的两个数据集熵不小于原数据集,停止划分
  else:
    return DecisionNode(results=uniquecounts(rows))
 
def print_tree(tree,indent=''):
  if tree.results!=None:
    print(str(tree.results))
  else:
    print(str(tree.col)+':'+str(tree.value)+'? ')
    print(indent+'T->',end='')
    print_tree(tree.tb,indent+' ')
    print(indent+'F->',end='')
    print_tree(tree.fb,indent+' ')
 
 
def getwidth(tree):
  if tree.tb==None and tree.fb==None: return 1
  return getwidth(tree.tb)+getwidth(tree.fb)
 
def getdepth(tree):
  if tree.tb==None and tree.fb==None: return 0
  return max(getdepth(tree.tb),getdepth(tree.fb))+1
 
 
def drawtree(tree,jpeg='tree.jpg'):
  w=getwidth(tree)*100
  h=getdepth(tree)*100+120
 
  img=Image.new('RGB',(w,h),(255,255,255))
  draw=ImageDraw.Draw(img)
 
  drawnode(draw,tree,w/2,20)
  img.save(jpeg,'JPEG')
 
def drawnode(draw,tree,x,y):
  if tree.results==None:
    # Get the width of each branch
    w1=getwidth(tree.fb)*100
    w2=getwidth(tree.tb)*100
 
    # Determine the total space required by this node
    left=x-(w1+w2)/2
    right=x+(w1+w2)/2
 
    # Draw the condition string
    draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))
 
    # Draw links to the branches
    draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
    draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))
   
    # Draw the branch nodes
    drawnode(draw,tree.fb,left+w1/2,y+100)
    drawnode(draw,tree.tb,right-w2/2,y+100)
  else:
    txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
    draw.text((x-20,y),txt,(0,0,0))

对测试数据进行分类(附带处理缺失数据)

def mdclassify(observation,tree):
  '''对缺失数据进行分类
   
  args:
  observation -- 发生信息缺失的数据项
  tree -- 训练完成的决策树
   
  返回代表该分类的结果字典
  '''
 
  # 判断数据是否到达叶节点
  if tree.results!=None:
    # 已经到达叶节点,返回结果result
    return tree.results
  else:
    # 对数据项的col列进行分析
    v=observation[tree.col]
 
    # 若col列数据缺失
    if v==None:
      #对tree的左右子树分别使用mdclassify,tr是左子树得到的结果字典,fr是右子树得到的结果字典
      tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb)
 
      # 分别以结果占总数比例计算得到左右子树的权重
      tcount=sum(tr.values())
      fcount=sum(fr.values())
      tw=float(tcount)/(tcount+fcount)
      fw=float(fcount)/(tcount+fcount)
      result={}
 
      # 计算左右子树的加权平均
      for k,v in tr.items(): 
        result[k]=v*tw
      for k,v in fr.items(): 
        # fr的结果k有可能并不在tr中,在result中初始化k
        if k not in result: 
          result[k]=0 
        # fr的结果累加到result中 
        result[k]+=v*fw
      return result
 
    # col列没有缺失,继续沿决策树分类
    else:
      if isinstance(v,int) or isinstance(v,float):
        if v>=tree.value: branch=tree.tb
        else: branch=tree.fb
      else:
        if v==tree.value: branch=tree.tb
        else: branch=tree.fb
      return mdclassify(observation,branch)
 
tree=build_tree(my_data)
print(mdclassify(['google',None,'yes',None],tree))
print(mdclassify(['google','France',None,None],tree))

决策树剪枝

def prune(tree,mingain):
  '''对决策树进行剪枝
   
  args:
  tree -- 决策树
  mingain -- 最小信息增益
   
  返回
  '''
  # 修剪非叶节点
  if tree.tb.results==None:
    prune(tree.tb,mingain)
  if tree.fb.results==None:
    prune(tree.fb,mingain)
  #合并两个叶子节点
  if tree.tb.results!=None and tree.fb.results!=None:
    tb,fb=[],[]
    for v,c in tree.tb.results.items():
      tb+=[[v]]*c
    for v,c in tree.fb.results.items():
      fb+=[[v]]*c
    #计算熵减少情况
    delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)
    #熵的增加量小于mingain,可以合并分支
    if delta<mingain:
      tree.tb,tree.fb=None,None
      tree.results=uniquecounts(tb+fb)
Python 相关文章推荐
Python实现网站文件的全备份和差异备份
Nov 30 Python
python 如何快速找出两个电子表中数据的差异
May 26 Python
PyQt5每天必学之滑块控件QSlider
Apr 20 Python
Sanic框架基于类的视图用法示例
Jul 18 Python
解决python写入带有中文的字符到文件错误的问题
Jan 31 Python
python将类似json的数据存储到MySQL中的实例
Jul 12 Python
Python 异常的捕获、异常的传递与主动抛出异常操作示例
Sep 23 Python
解决Keyerror ''acc'' KeyError: ''val_acc''问题
Jun 18 Python
Python爬取你好李焕英豆瓣短评生成词云的示例代码
Feb 24 Python
opencv实现图像平移效果
Mar 24 Python
秀!学妹看见都惊呆的Python小招数!【详细语言特性使用技巧】
Apr 27 Python
Python实现学生管理系统并生成exe可执行文件详解流程
Jan 22 Python
Python tkinter的grid布局及Text动态显示方法
Oct 11 #Python
对python requests的content和text方法的区别详解
Oct 11 #Python
使用pip发布Python程序的方法步骤
Oct 11 #Python
对python Tkinter Text的用法详解
Oct 11 #Python
python数据批量写入ScrolledText的优化方法
Oct 11 #Python
攻击者是如何将PHP Phar包伪装成图像以绕过文件类型检测的(推荐)
Oct 11 #Python
python中join()方法介绍
Oct 11 #Python
You might like
也谈截取首页新闻 - 范例
2006/10/09 PHP
php获取mysql版本的几种方法小结
2008/03/25 PHP
php实现网站插件机制的方法
2009/11/10 PHP
给ECShop添加最新评论
2015/01/07 PHP
PHP判断FORM表单或URL参数来的数据是否为整数的方法
2016/03/25 PHP
PHP微信企业号开发之回调模式开启与用法示例
2017/11/25 PHP
javascript 带有滚动条的表格,标题固定,带排序功能.
2009/11/13 Javascript
jQuery的一些特性和用法整理小结
2010/01/13 Javascript
jQuery学习笔记之jQuery的事件
2010/12/22 Javascript
用jquery生成二级菜单的实例代码
2013/06/24 Javascript
table对象中的insertRow与deleteRow使用示例
2014/01/26 Javascript
jQuery回调函数的定义及用法实例
2014/12/23 Javascript
每天一篇javascript学习小结(Date对象)
2015/11/13 Javascript
javascript十六进制数字和ASCII字符之间的转换方法
2016/12/27 Javascript
bootstrap table 数据表格行内修改的实现代码
2017/02/13 Javascript
jQuery Form插件使用详解_动力节点Java学院整理
2017/07/17 jQuery
JS实现页面打印(整体、局部)
2017/08/18 Javascript
深入浅析Vue.js中 computed和methods不同机制
2018/03/22 Javascript
[54:53]完美世界DOTA2联赛PWL S2 GXR vs PXG 第二场 11.18
2020/11/18 DOTA
[48:51]完美世界DOTA2联赛PWL S2 Magma vs InkIce 第一场 11.28
2020/12/02 DOTA
简单介绍Python的轻便web框架Bottle
2015/04/08 Python
python利用matplotlib库绘制饼图的方法示例
2016/12/18 Python
Python学习小技巧之列表项的推导式与过滤操作
2017/05/20 Python
Python3实现的画图及加载图片动画效果示例
2018/01/19 Python
python3+PyQt5实现柱状图
2018/04/24 Python
TensorFlow入门使用 tf.train.Saver()保存模型
2018/04/24 Python
浅析python中while循环和for循环
2019/11/19 Python
Pycharm激活方法及详细教程(详细且实用)
2020/05/12 Python
递归计算如下递归函数的值(斐波拉契)
2012/02/04 面试题
师范生自我鉴定范文
2013/10/05 职场文书
个人求职信范文分享
2014/01/31 职场文书
《画杨桃》教学反思
2014/04/13 职场文书
卫生保健工作总结2015
2015/05/18 职场文书
初中思想品德教学反思
2016/02/24 职场文书
【DOTA2】高能暴走TK秀!PSG LGD vs ASTER - DPC 2022 WINTER TOUR CN
2022/04/02 DOTA
SQL中的连接查询详解
2022/06/21 SQL Server