解读python如何实现决策树算法


Posted in Python onOctober 11, 2018

数据描述

每条数据项储存在列表中,最后一列储存结果

多条数据项形成数据集

data=[[d1,d2,d3...dn,result],
   [d1,d2,d3...dn,result],
        .
        .
   [d1,d2,d3...dn,result]]

决策树数据结构

class DecisionNode:
  '''决策树节点
  '''
   
  def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
    '''初始化决策树节点
     
    args:    
    col -- 按数据集的col列划分数据集
    value -- 以value作为划分col列的参照
    result -- 只有叶子节点有,代表最终划分出的子数据集结果统计信息。{‘结果':结果出现次数}
    rb,fb -- 代表左右子树
    '''
    self.col=col
    self.value=value
    self.results=results
    self.tb=tb
    self.fb=fb

决策树分类的最终结果是将数据项划分出了若干子集,其中每个子集的结果都一样,所以这里采用{‘结果':结果出现次数}的方式表达每个子集

def pideset(rows,column,value):
  '''依据数据集rows的column列的值,判断其与参考值value的关系对数据集进行拆分
    返回两个数据集
  '''
  split_function=None
  #value是数值类型
  if isinstance(value,int) or isinstance(value,float):
    #定义lambda函数当row[column]>=value时返回true
    split_function=lambda row:row[column]>=value
  #value是字符类型
  else:
    #定义lambda函数当row[column]==value时返回true
    split_function=lambda row:row[column]==value
  #将数据集拆分成两个
  set1=[row for row in rows if split_function(row)]
  set2=[row for row in rows if not split_function(row)]
  #返回两个数据集
  return (set1,set2)
 
def uniquecounts(rows):
  '''计算数据集rows中有几种最终结果,计算结果出现次数,返回一个字典
  '''
  results={}
  for row in rows:
    r=row[len(row)-1]
    if r not in results: results[r]=0
    results[r]+=1
  return results
 
def giniimpurity(rows):
  '''返回rows数据集的基尼不纯度
  '''
  total=len(rows)
  counts=uniquecounts(rows)
  imp=0
  for k1 in counts:
    p1=float(counts[k1])/total
    for k2 in counts:
      if k1==k2: continue
      p2=float(counts[k2])/total
      imp+=p1*p2
  return imp
 
def entropy(rows):
  '''返回rows数据集的熵
  '''
  from math import log
  log2=lambda x:log(x)/log(2) 
  results=uniquecounts(rows)
  ent=0.0
  for r in results.keys():
    p=float(results[r])/len(rows)
    ent=ent-p*log2(p)
  return ent
 
def build_tree(rows,scoref=entropy):
  '''构造决策树
  '''
  if len(rows)==0: return DecisionNode()
  current_score=scoref(rows)
 
  # 最佳信息增益
  best_gain=0.0
  #
  best_criteria=None
  #最佳划分
  best_sets=None
 
  column_count=len(rows[0])-1
  #遍历数据集的列,确定分割顺序
  for col in range(0,column_count):
    column_values={}
    # 构造字典
    for row in rows:
      column_values[row[col]]=1
    for value in column_values.keys():
      (set1,set2)=pideset(rows,col,value)
      p=float(len(set1))/len(rows)
      # 计算信息增益
      gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
      if gain>best_gain and len(set1)>0 and len(set2)>0:
        best_gain=gain
        best_criteria=(col,value)
        best_sets=(set1,set2)
  # 如果划分的两个数据集熵小于原数据集,进一步划分它们
  if best_gain>0:
    trueBranch=build_tree(best_sets[0])
    falseBranch=build_tree(best_sets[1])
    return DecisionNode(col=best_criteria[0],value=best_criteria[1],
            tb=trueBranch,fb=falseBranch)
  # 如果划分的两个数据集熵不小于原数据集,停止划分
  else:
    return DecisionNode(results=uniquecounts(rows))
 
def print_tree(tree,indent=''):
  if tree.results!=None:
    print(str(tree.results))
  else:
    print(str(tree.col)+':'+str(tree.value)+'? ')
    print(indent+'T->',end='')
    print_tree(tree.tb,indent+' ')
    print(indent+'F->',end='')
    print_tree(tree.fb,indent+' ')
 
 
def getwidth(tree):
  if tree.tb==None and tree.fb==None: return 1
  return getwidth(tree.tb)+getwidth(tree.fb)
 
def getdepth(tree):
  if tree.tb==None and tree.fb==None: return 0
  return max(getdepth(tree.tb),getdepth(tree.fb))+1
 
 
def drawtree(tree,jpeg='tree.jpg'):
  w=getwidth(tree)*100
  h=getdepth(tree)*100+120
 
  img=Image.new('RGB',(w,h),(255,255,255))
  draw=ImageDraw.Draw(img)
 
  drawnode(draw,tree,w/2,20)
  img.save(jpeg,'JPEG')
 
def drawnode(draw,tree,x,y):
  if tree.results==None:
    # Get the width of each branch
    w1=getwidth(tree.fb)*100
    w2=getwidth(tree.tb)*100
 
    # Determine the total space required by this node
    left=x-(w1+w2)/2
    right=x+(w1+w2)/2
 
    # Draw the condition string
    draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))
 
    # Draw links to the branches
    draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
    draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))
   
    # Draw the branch nodes
    drawnode(draw,tree.fb,left+w1/2,y+100)
    drawnode(draw,tree.tb,right-w2/2,y+100)
  else:
    txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
    draw.text((x-20,y),txt,(0,0,0))

对测试数据进行分类(附带处理缺失数据)

def mdclassify(observation,tree):
  '''对缺失数据进行分类
   
  args:
  observation -- 发生信息缺失的数据项
  tree -- 训练完成的决策树
   
  返回代表该分类的结果字典
  '''
 
  # 判断数据是否到达叶节点
  if tree.results!=None:
    # 已经到达叶节点,返回结果result
    return tree.results
  else:
    # 对数据项的col列进行分析
    v=observation[tree.col]
 
    # 若col列数据缺失
    if v==None:
      #对tree的左右子树分别使用mdclassify,tr是左子树得到的结果字典,fr是右子树得到的结果字典
      tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb)
 
      # 分别以结果占总数比例计算得到左右子树的权重
      tcount=sum(tr.values())
      fcount=sum(fr.values())
      tw=float(tcount)/(tcount+fcount)
      fw=float(fcount)/(tcount+fcount)
      result={}
 
      # 计算左右子树的加权平均
      for k,v in tr.items(): 
        result[k]=v*tw
      for k,v in fr.items(): 
        # fr的结果k有可能并不在tr中,在result中初始化k
        if k not in result: 
          result[k]=0 
        # fr的结果累加到result中 
        result[k]+=v*fw
      return result
 
    # col列没有缺失,继续沿决策树分类
    else:
      if isinstance(v,int) or isinstance(v,float):
        if v>=tree.value: branch=tree.tb
        else: branch=tree.fb
      else:
        if v==tree.value: branch=tree.tb
        else: branch=tree.fb
      return mdclassify(observation,branch)
 
tree=build_tree(my_data)
print(mdclassify(['google',None,'yes',None],tree))
print(mdclassify(['google','France',None,None],tree))

决策树剪枝

def prune(tree,mingain):
  '''对决策树进行剪枝
   
  args:
  tree -- 决策树
  mingain -- 最小信息增益
   
  返回
  '''
  # 修剪非叶节点
  if tree.tb.results==None:
    prune(tree.tb,mingain)
  if tree.fb.results==None:
    prune(tree.fb,mingain)
  #合并两个叶子节点
  if tree.tb.results!=None and tree.fb.results!=None:
    tb,fb=[],[]
    for v,c in tree.tb.results.items():
      tb+=[[v]]*c
    for v,c in tree.fb.results.items():
      fb+=[[v]]*c
    #计算熵减少情况
    delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)
    #熵的增加量小于mingain,可以合并分支
    if delta<mingain:
      tree.tb,tree.fb=None,None
      tree.results=uniquecounts(tb+fb)
Python 相关文章推荐
python生成指定尺寸缩略图的示例
May 07 Python
python 对象和json互相转换方法
Mar 22 Python
python发送邮件脚本
May 22 Python
python中的二维列表实例详解
Jun 19 Python
Python实现输入二叉树的先序和中序遍历,再输出后序遍历操作示例
Jul 27 Python
如何在django里上传csv文件并进行入库处理的方法
Jan 02 Python
python utc datetime转换为时间戳的方法
Jan 15 Python
使用PyTorch将文件夹下的图片分为训练集和验证集实例
Jan 08 Python
python实现PCA降维的示例详解
Feb 24 Python
Django中的AutoField字段使用
May 18 Python
安装pyecharts1.8.0版本后导入pyecharts模块绘图时报错: “所有图表类型将在 v1.9.0 版本开始强制使用 ChartItem 进行数据项配置 ”的解决方法
Aug 18 Python
为2021年的第一场雪锦上添花:用matplotlib绘制雪花和雪景
Jan 05 Python
Python tkinter的grid布局及Text动态显示方法
Oct 11 #Python
对python requests的content和text方法的区别详解
Oct 11 #Python
使用pip发布Python程序的方法步骤
Oct 11 #Python
对python Tkinter Text的用法详解
Oct 11 #Python
python数据批量写入ScrolledText的优化方法
Oct 11 #Python
攻击者是如何将PHP Phar包伪装成图像以绕过文件类型检测的(推荐)
Oct 11 #Python
python中join()方法介绍
Oct 11 #Python
You might like
星际原理概述
2020/03/04 星际争霸
php常用的安全过滤函数集锦
2014/10/09 PHP
PHP实现自动识别原编码并对字符串进行编码转换的方法
2016/07/13 PHP
php cookie用户登录的详解及实例代码
2017/01/03 PHP
PHP数字金额转换成中文大写显示
2019/01/05 PHP
基于laravel Request的所有方法详解
2019/09/29 PHP
讲两件事:1.this指针的用法小探. 2.ie的attachEvent和firefox的addEventListener在事件处理上的区别
2007/04/12 Javascript
js操作Xml(向服务器发送Xml,处理服务器返回的Xml)(IE下有效)
2009/01/30 Javascript
JQuery中extend的用法实例分析
2015/02/08 Javascript
JavaScript控制table某列不显示的方法
2015/03/16 Javascript
JS提交form表单实例分析
2015/12/10 Javascript
Javascript实现图片轮播效果(二)图片序列节点的控制实现
2016/02/17 Javascript
javascript的几种写法总结
2016/09/30 Javascript
jQuery插件扩展操作入门示例
2017/01/16 Javascript
深入理解vue.js中$watch的oldvalue与newValue
2017/08/07 Javascript
Vue实现active点击切换方法
2018/03/16 Javascript
基于jQuery使用Ajax动态执行模糊查询功能
2018/07/05 jQuery
Jquery的autocomplete插件用法及参数讲解
2019/03/12 jQuery
Javascript和jquery在selenium的使用过程
2019/10/31 jQuery
在Python的框架中为MySQL实现restful接口的教程
2015/04/08 Python
使用python list 查找所有匹配元素的位置实例
2019/06/11 Python
pyqt5 使用cv2 显示图片,摄像头的实例
2019/06/27 Python
详解django实现自定义manage命令的扩展
2019/08/13 Python
使用 Python 合并多个格式一致的 Excel 文件(推荐)
2019/12/09 Python
Python对称的二叉树多种思路实现方法
2020/02/28 Python
python opencv肤色检测的实现示例
2020/12/21 Python
高三自我鉴定
2013/10/23 职场文书
销售实习自我鉴定
2013/12/07 职场文书
老同学聚会感言
2014/02/23 职场文书
文明城市创建标语
2014/06/16 职场文书
2014年党员自我评议总结
2014/09/23 职场文书
员工辞职信范文
2015/03/02 职场文书
2015年度党员自我评价范文
2015/03/03 职场文书
标准发言稿结尾
2019/07/18 职场文书
python简单验证码识别的实现过程
2021/06/20 Python
Python函数中apply、map、applymap的区别
2021/11/27 Python