解读python如何实现决策树算法


Posted in Python onOctober 11, 2018

数据描述

每条数据项储存在列表中,最后一列储存结果

多条数据项形成数据集

data=[[d1,d2,d3...dn,result],
   [d1,d2,d3...dn,result],
        .
        .
   [d1,d2,d3...dn,result]]

决策树数据结构

class DecisionNode:
  '''决策树节点
  '''
   
  def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
    '''初始化决策树节点
     
    args:    
    col -- 按数据集的col列划分数据集
    value -- 以value作为划分col列的参照
    result -- 只有叶子节点有,代表最终划分出的子数据集结果统计信息。{‘结果':结果出现次数}
    rb,fb -- 代表左右子树
    '''
    self.col=col
    self.value=value
    self.results=results
    self.tb=tb
    self.fb=fb

决策树分类的最终结果是将数据项划分出了若干子集,其中每个子集的结果都一样,所以这里采用{‘结果':结果出现次数}的方式表达每个子集

def pideset(rows,column,value):
  '''依据数据集rows的column列的值,判断其与参考值value的关系对数据集进行拆分
    返回两个数据集
  '''
  split_function=None
  #value是数值类型
  if isinstance(value,int) or isinstance(value,float):
    #定义lambda函数当row[column]>=value时返回true
    split_function=lambda row:row[column]>=value
  #value是字符类型
  else:
    #定义lambda函数当row[column]==value时返回true
    split_function=lambda row:row[column]==value
  #将数据集拆分成两个
  set1=[row for row in rows if split_function(row)]
  set2=[row for row in rows if not split_function(row)]
  #返回两个数据集
  return (set1,set2)
 
def uniquecounts(rows):
  '''计算数据集rows中有几种最终结果,计算结果出现次数,返回一个字典
  '''
  results={}
  for row in rows:
    r=row[len(row)-1]
    if r not in results: results[r]=0
    results[r]+=1
  return results
 
def giniimpurity(rows):
  '''返回rows数据集的基尼不纯度
  '''
  total=len(rows)
  counts=uniquecounts(rows)
  imp=0
  for k1 in counts:
    p1=float(counts[k1])/total
    for k2 in counts:
      if k1==k2: continue
      p2=float(counts[k2])/total
      imp+=p1*p2
  return imp
 
def entropy(rows):
  '''返回rows数据集的熵
  '''
  from math import log
  log2=lambda x:log(x)/log(2) 
  results=uniquecounts(rows)
  ent=0.0
  for r in results.keys():
    p=float(results[r])/len(rows)
    ent=ent-p*log2(p)
  return ent
 
def build_tree(rows,scoref=entropy):
  '''构造决策树
  '''
  if len(rows)==0: return DecisionNode()
  current_score=scoref(rows)
 
  # 最佳信息增益
  best_gain=0.0
  #
  best_criteria=None
  #最佳划分
  best_sets=None
 
  column_count=len(rows[0])-1
  #遍历数据集的列,确定分割顺序
  for col in range(0,column_count):
    column_values={}
    # 构造字典
    for row in rows:
      column_values[row[col]]=1
    for value in column_values.keys():
      (set1,set2)=pideset(rows,col,value)
      p=float(len(set1))/len(rows)
      # 计算信息增益
      gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
      if gain>best_gain and len(set1)>0 and len(set2)>0:
        best_gain=gain
        best_criteria=(col,value)
        best_sets=(set1,set2)
  # 如果划分的两个数据集熵小于原数据集,进一步划分它们
  if best_gain>0:
    trueBranch=build_tree(best_sets[0])
    falseBranch=build_tree(best_sets[1])
    return DecisionNode(col=best_criteria[0],value=best_criteria[1],
            tb=trueBranch,fb=falseBranch)
  # 如果划分的两个数据集熵不小于原数据集,停止划分
  else:
    return DecisionNode(results=uniquecounts(rows))
 
def print_tree(tree,indent=''):
  if tree.results!=None:
    print(str(tree.results))
  else:
    print(str(tree.col)+':'+str(tree.value)+'? ')
    print(indent+'T->',end='')
    print_tree(tree.tb,indent+' ')
    print(indent+'F->',end='')
    print_tree(tree.fb,indent+' ')
 
 
def getwidth(tree):
  if tree.tb==None and tree.fb==None: return 1
  return getwidth(tree.tb)+getwidth(tree.fb)
 
def getdepth(tree):
  if tree.tb==None and tree.fb==None: return 0
  return max(getdepth(tree.tb),getdepth(tree.fb))+1
 
 
def drawtree(tree,jpeg='tree.jpg'):
  w=getwidth(tree)*100
  h=getdepth(tree)*100+120
 
  img=Image.new('RGB',(w,h),(255,255,255))
  draw=ImageDraw.Draw(img)
 
  drawnode(draw,tree,w/2,20)
  img.save(jpeg,'JPEG')
 
def drawnode(draw,tree,x,y):
  if tree.results==None:
    # Get the width of each branch
    w1=getwidth(tree.fb)*100
    w2=getwidth(tree.tb)*100
 
    # Determine the total space required by this node
    left=x-(w1+w2)/2
    right=x+(w1+w2)/2
 
    # Draw the condition string
    draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))
 
    # Draw links to the branches
    draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
    draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))
   
    # Draw the branch nodes
    drawnode(draw,tree.fb,left+w1/2,y+100)
    drawnode(draw,tree.tb,right-w2/2,y+100)
  else:
    txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
    draw.text((x-20,y),txt,(0,0,0))

对测试数据进行分类(附带处理缺失数据)

def mdclassify(observation,tree):
  '''对缺失数据进行分类
   
  args:
  observation -- 发生信息缺失的数据项
  tree -- 训练完成的决策树
   
  返回代表该分类的结果字典
  '''
 
  # 判断数据是否到达叶节点
  if tree.results!=None:
    # 已经到达叶节点,返回结果result
    return tree.results
  else:
    # 对数据项的col列进行分析
    v=observation[tree.col]
 
    # 若col列数据缺失
    if v==None:
      #对tree的左右子树分别使用mdclassify,tr是左子树得到的结果字典,fr是右子树得到的结果字典
      tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb)
 
      # 分别以结果占总数比例计算得到左右子树的权重
      tcount=sum(tr.values())
      fcount=sum(fr.values())
      tw=float(tcount)/(tcount+fcount)
      fw=float(fcount)/(tcount+fcount)
      result={}
 
      # 计算左右子树的加权平均
      for k,v in tr.items(): 
        result[k]=v*tw
      for k,v in fr.items(): 
        # fr的结果k有可能并不在tr中,在result中初始化k
        if k not in result: 
          result[k]=0 
        # fr的结果累加到result中 
        result[k]+=v*fw
      return result
 
    # col列没有缺失,继续沿决策树分类
    else:
      if isinstance(v,int) or isinstance(v,float):
        if v>=tree.value: branch=tree.tb
        else: branch=tree.fb
      else:
        if v==tree.value: branch=tree.tb
        else: branch=tree.fb
      return mdclassify(observation,branch)
 
tree=build_tree(my_data)
print(mdclassify(['google',None,'yes',None],tree))
print(mdclassify(['google','France',None,None],tree))

决策树剪枝

def prune(tree,mingain):
  '''对决策树进行剪枝
   
  args:
  tree -- 决策树
  mingain -- 最小信息增益
   
  返回
  '''
  # 修剪非叶节点
  if tree.tb.results==None:
    prune(tree.tb,mingain)
  if tree.fb.results==None:
    prune(tree.fb,mingain)
  #合并两个叶子节点
  if tree.tb.results!=None and tree.fb.results!=None:
    tb,fb=[],[]
    for v,c in tree.tb.results.items():
      tb+=[[v]]*c
    for v,c in tree.fb.results.items():
      fb+=[[v]]*c
    #计算熵减少情况
    delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)
    #熵的增加量小于mingain,可以合并分支
    if delta<mingain:
      tree.tb,tree.fb=None,None
      tree.results=uniquecounts(tb+fb)
Python 相关文章推荐
Python UnicodeEncodeError: 'gbk' codec can't encode character 解决方法
Apr 24 Python
python实现识别相似图片小结
Feb 22 Python
用十张图详解TensorFlow数据读取机制(附代码)
Feb 06 Python
Python实现Dijkstra算法
Oct 17 Python
pandas.DataFrame删除/选取含有特定数值的行或列实例
Nov 07 Python
利用Python库Scapy解析pcap文件的方法
Jul 23 Python
python kafka 多线程消费者&amp;手动提交实例
Dec 21 Python
将python文件打包exe独立运行程序方法详解
Feb 12 Python
浅谈keras中的Merge层(实现层的相加、相减、相乘实例)
May 23 Python
matplotlib基础绘图命令之imshow的使用
Aug 13 Python
python判断all函数输出结果是否为true的方法
Dec 03 Python
详解matplotlib绘图样式(style)初探
Feb 03 Python
Python tkinter的grid布局及Text动态显示方法
Oct 11 #Python
对python requests的content和text方法的区别详解
Oct 11 #Python
使用pip发布Python程序的方法步骤
Oct 11 #Python
对python Tkinter Text的用法详解
Oct 11 #Python
python数据批量写入ScrolledText的优化方法
Oct 11 #Python
攻击者是如何将PHP Phar包伪装成图像以绕过文件类型检测的(推荐)
Oct 11 #Python
python中join()方法介绍
Oct 11 #Python
You might like
php更改目录及子目录下所有的文件后缀扩展名的代码
2010/10/12 PHP
PHP如何获取当前主机、域名、网址、路径、端口等参数
2017/06/09 PHP
PHP笛卡尔积实现原理及代码实例
2020/12/09 PHP
XML+XSL 与 HTML 两种方案的结合
2007/04/22 Javascript
利用ASP发送和接收XML数据的处理方法与代码
2007/11/13 Javascript
在html页面上拖放移动标签
2010/01/08 Javascript
jQuery 对Select的操作备忘记录
2011/07/04 Javascript
node在两个div之间移动,用ztree实现
2013/03/06 Javascript
js点击更换背景颜色或图片的实例代码
2013/06/25 Javascript
jquery的each方法使用示例分享
2014/03/25 Javascript
简单讲解AngularJS的Routing路由的定义与使用
2016/03/05 Javascript
Three.js学习之几何形状
2016/08/01 Javascript
jQuery简单实现彩色云标签效果示例
2016/08/01 Javascript
微信小程序 摇一摇抽奖简单实例实现代码
2017/01/09 Javascript
JavaScript简单验证表单空值及邮箱格式的方法
2017/01/20 Javascript
解决Mac node版本升级失败的问题
2018/05/16 Javascript
学习python (1)
2006/10/31 Python
python每次处理固定个数的字符的方法总结
2013/01/29 Python
python实现批量监控网站
2016/09/09 Python
利用Python实现Windows定时关机功能
2017/03/21 Python
Django查询数据库的性能优化示例代码
2017/09/24 Python
Python简单计算给定某一年的某一天是星期几示例
2018/06/27 Python
详解Python3中setuptools、Pip安装教程
2019/06/18 Python
python生成任意频率正弦波方式
2020/02/25 Python
Python过滤掉numpy.array中非nan数据实例
2020/06/08 Python
2021年值得向Python开发者推荐的VS Code扩展插件
2021/01/25 Python
css3实现冲击波效果的示例代码
2018/01/11 HTML / CSS
安全责任书范本
2014/04/15 职场文书
捐助贫困学生倡议书
2014/05/16 职场文书
乡镇党员干部群众路线对照检查材料思想汇报
2014/09/28 职场文书
企业承诺书格式范文
2015/04/28 职场文书
导游词之珠海轮廓
2019/10/25 职场文书
基于python实现银行管理系统
2021/04/20 Python
python 中的@运算符使用
2021/05/26 Python
python中数组和列表的简单实例
2022/03/25 Python
win10电脑双屏显示一个黑屏怎么办?win10电脑双屏显示一个黑屏解决方法
2022/07/15 数码科技