解读python如何实现决策树算法


Posted in Python onOctober 11, 2018

数据描述

每条数据项储存在列表中,最后一列储存结果

多条数据项形成数据集

data=[[d1,d2,d3...dn,result],
   [d1,d2,d3...dn,result],
        .
        .
   [d1,d2,d3...dn,result]]

决策树数据结构

class DecisionNode:
  '''决策树节点
  '''
   
  def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
    '''初始化决策树节点
     
    args:    
    col -- 按数据集的col列划分数据集
    value -- 以value作为划分col列的参照
    result -- 只有叶子节点有,代表最终划分出的子数据集结果统计信息。{‘结果':结果出现次数}
    rb,fb -- 代表左右子树
    '''
    self.col=col
    self.value=value
    self.results=results
    self.tb=tb
    self.fb=fb

决策树分类的最终结果是将数据项划分出了若干子集,其中每个子集的结果都一样,所以这里采用{‘结果':结果出现次数}的方式表达每个子集

def pideset(rows,column,value):
  '''依据数据集rows的column列的值,判断其与参考值value的关系对数据集进行拆分
    返回两个数据集
  '''
  split_function=None
  #value是数值类型
  if isinstance(value,int) or isinstance(value,float):
    #定义lambda函数当row[column]>=value时返回true
    split_function=lambda row:row[column]>=value
  #value是字符类型
  else:
    #定义lambda函数当row[column]==value时返回true
    split_function=lambda row:row[column]==value
  #将数据集拆分成两个
  set1=[row for row in rows if split_function(row)]
  set2=[row for row in rows if not split_function(row)]
  #返回两个数据集
  return (set1,set2)
 
def uniquecounts(rows):
  '''计算数据集rows中有几种最终结果,计算结果出现次数,返回一个字典
  '''
  results={}
  for row in rows:
    r=row[len(row)-1]
    if r not in results: results[r]=0
    results[r]+=1
  return results
 
def giniimpurity(rows):
  '''返回rows数据集的基尼不纯度
  '''
  total=len(rows)
  counts=uniquecounts(rows)
  imp=0
  for k1 in counts:
    p1=float(counts[k1])/total
    for k2 in counts:
      if k1==k2: continue
      p2=float(counts[k2])/total
      imp+=p1*p2
  return imp
 
def entropy(rows):
  '''返回rows数据集的熵
  '''
  from math import log
  log2=lambda x:log(x)/log(2) 
  results=uniquecounts(rows)
  ent=0.0
  for r in results.keys():
    p=float(results[r])/len(rows)
    ent=ent-p*log2(p)
  return ent
 
def build_tree(rows,scoref=entropy):
  '''构造决策树
  '''
  if len(rows)==0: return DecisionNode()
  current_score=scoref(rows)
 
  # 最佳信息增益
  best_gain=0.0
  #
  best_criteria=None
  #最佳划分
  best_sets=None
 
  column_count=len(rows[0])-1
  #遍历数据集的列,确定分割顺序
  for col in range(0,column_count):
    column_values={}
    # 构造字典
    for row in rows:
      column_values[row[col]]=1
    for value in column_values.keys():
      (set1,set2)=pideset(rows,col,value)
      p=float(len(set1))/len(rows)
      # 计算信息增益
      gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
      if gain>best_gain and len(set1)>0 and len(set2)>0:
        best_gain=gain
        best_criteria=(col,value)
        best_sets=(set1,set2)
  # 如果划分的两个数据集熵小于原数据集,进一步划分它们
  if best_gain>0:
    trueBranch=build_tree(best_sets[0])
    falseBranch=build_tree(best_sets[1])
    return DecisionNode(col=best_criteria[0],value=best_criteria[1],
            tb=trueBranch,fb=falseBranch)
  # 如果划分的两个数据集熵不小于原数据集,停止划分
  else:
    return DecisionNode(results=uniquecounts(rows))
 
def print_tree(tree,indent=''):
  if tree.results!=None:
    print(str(tree.results))
  else:
    print(str(tree.col)+':'+str(tree.value)+'? ')
    print(indent+'T->',end='')
    print_tree(tree.tb,indent+' ')
    print(indent+'F->',end='')
    print_tree(tree.fb,indent+' ')
 
 
def getwidth(tree):
  if tree.tb==None and tree.fb==None: return 1
  return getwidth(tree.tb)+getwidth(tree.fb)
 
def getdepth(tree):
  if tree.tb==None and tree.fb==None: return 0
  return max(getdepth(tree.tb),getdepth(tree.fb))+1
 
 
def drawtree(tree,jpeg='tree.jpg'):
  w=getwidth(tree)*100
  h=getdepth(tree)*100+120
 
  img=Image.new('RGB',(w,h),(255,255,255))
  draw=ImageDraw.Draw(img)
 
  drawnode(draw,tree,w/2,20)
  img.save(jpeg,'JPEG')
 
def drawnode(draw,tree,x,y):
  if tree.results==None:
    # Get the width of each branch
    w1=getwidth(tree.fb)*100
    w2=getwidth(tree.tb)*100
 
    # Determine the total space required by this node
    left=x-(w1+w2)/2
    right=x+(w1+w2)/2
 
    # Draw the condition string
    draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))
 
    # Draw links to the branches
    draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
    draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))
   
    # Draw the branch nodes
    drawnode(draw,tree.fb,left+w1/2,y+100)
    drawnode(draw,tree.tb,right-w2/2,y+100)
  else:
    txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
    draw.text((x-20,y),txt,(0,0,0))

对测试数据进行分类(附带处理缺失数据)

def mdclassify(observation,tree):
  '''对缺失数据进行分类
   
  args:
  observation -- 发生信息缺失的数据项
  tree -- 训练完成的决策树
   
  返回代表该分类的结果字典
  '''
 
  # 判断数据是否到达叶节点
  if tree.results!=None:
    # 已经到达叶节点,返回结果result
    return tree.results
  else:
    # 对数据项的col列进行分析
    v=observation[tree.col]
 
    # 若col列数据缺失
    if v==None:
      #对tree的左右子树分别使用mdclassify,tr是左子树得到的结果字典,fr是右子树得到的结果字典
      tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb)
 
      # 分别以结果占总数比例计算得到左右子树的权重
      tcount=sum(tr.values())
      fcount=sum(fr.values())
      tw=float(tcount)/(tcount+fcount)
      fw=float(fcount)/(tcount+fcount)
      result={}
 
      # 计算左右子树的加权平均
      for k,v in tr.items(): 
        result[k]=v*tw
      for k,v in fr.items(): 
        # fr的结果k有可能并不在tr中,在result中初始化k
        if k not in result: 
          result[k]=0 
        # fr的结果累加到result中 
        result[k]+=v*fw
      return result
 
    # col列没有缺失,继续沿决策树分类
    else:
      if isinstance(v,int) or isinstance(v,float):
        if v>=tree.value: branch=tree.tb
        else: branch=tree.fb
      else:
        if v==tree.value: branch=tree.tb
        else: branch=tree.fb
      return mdclassify(observation,branch)
 
tree=build_tree(my_data)
print(mdclassify(['google',None,'yes',None],tree))
print(mdclassify(['google','France',None,None],tree))

决策树剪枝

def prune(tree,mingain):
  '''对决策树进行剪枝
   
  args:
  tree -- 决策树
  mingain -- 最小信息增益
   
  返回
  '''
  # 修剪非叶节点
  if tree.tb.results==None:
    prune(tree.tb,mingain)
  if tree.fb.results==None:
    prune(tree.fb,mingain)
  #合并两个叶子节点
  if tree.tb.results!=None and tree.fb.results!=None:
    tb,fb=[],[]
    for v,c in tree.tb.results.items():
      tb+=[[v]]*c
    for v,c in tree.fb.results.items():
      fb+=[[v]]*c
    #计算熵减少情况
    delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)
    #熵的增加量小于mingain,可以合并分支
    if delta<mingain:
      tree.tb,tree.fb=None,None
      tree.results=uniquecounts(tb+fb)
Python 相关文章推荐
python 布尔操作实现代码
Mar 23 Python
win7上python2.7连接mysql数据库的方法
Jan 14 Python
Python实现按学生年龄排序的实际问题详解
Aug 29 Python
一文了解Python并发编程的工程实现方法
May 31 Python
python numpy 反转 reverse示例
Dec 04 Python
centos7中安装python3.6.4的教程
Dec 11 Python
Python自动化测试笔试面试题精选
Mar 12 Python
完美解决pyinstaller打包报错找不到依赖pypiwin32或pywin32-ctypes的错误
Apr 01 Python
python实现126邮箱发送邮件
May 20 Python
vscode+PyQt5安装详解步骤
Aug 12 Python
Jupyter安装拓展nbextensions及解决官网下载慢的问题
Mar 03 Python
python通过新建环境安装tfx的问题
May 20 Python
Python tkinter的grid布局及Text动态显示方法
Oct 11 #Python
对python requests的content和text方法的区别详解
Oct 11 #Python
使用pip发布Python程序的方法步骤
Oct 11 #Python
对python Tkinter Text的用法详解
Oct 11 #Python
python数据批量写入ScrolledText的优化方法
Oct 11 #Python
攻击者是如何将PHP Phar包伪装成图像以绕过文件类型检测的(推荐)
Oct 11 #Python
python中join()方法介绍
Oct 11 #Python
You might like
PHP与jquery实时显示网站在线人数实例详解
2016/12/02 PHP
PHP实现的简单异常处理类示例
2017/05/04 PHP
JAVASCRIPT对象及属性
2007/02/13 Javascript
JavaScript学习小结(一)——JavaScript入门基础
2015/09/02 Javascript
JQuery+Ajax实现数据查询、排序和分页功能
2015/09/27 Javascript
jQuery插件cxSelect多级联动下拉菜单实例解析
2016/06/24 Javascript
使用JavaScript获取Request中参数的值方法
2016/09/27 Javascript
让浏览器崩溃的12行JS代码(DoS攻击分析及防御)
2016/10/10 Javascript
JS/jquery实现一个网页内同时调用多个倒计时的方法
2017/04/27 jQuery
使用jQuery实现鼠标点击左右按钮滑动切换
2017/08/04 jQuery
详解a++和++a的区别
2017/08/30 Javascript
详解layui弹窗父子窗口之间传参数的方法
2018/01/16 Javascript
javaScript强制保留两位小数的输入数校验和小数保留问题
2018/05/09 Javascript
JS实现的简单下拉框联动功能示例
2018/05/11 Javascript
详解js静态检查工具eslint配置文件
2018/11/23 Javascript
javascript刷新父页面方法汇总详解
2019/10/10 Javascript
vue 获取元素额外生成的data-v-xxx操作
2020/09/09 Javascript
vue form表单post请求结合Servlet实现文件上传功能
2021/01/22 Vue.js
[01:03:59]2018DOTA2亚洲邀请赛3月30日 小组赛B组VGJ.T VS Secret
2018/03/31 DOTA
[44:58]2018DOTA2亚洲邀请赛 4.5 淘汰赛 LGD vs Liquid 第二场
2018/04/06 DOTA
python list语法学习(带例子)
2013/11/01 Python
Python文件读取的3种方法及路径转义
2015/06/21 Python
python实现拓扑排序的基本教程
2018/03/11 Python
解决python中使用PYQT时中文乱码问题
2019/06/17 Python
使用python进行广告点击率的预测的实现
2019/07/04 Python
python cv2读取rtsp实时码流按时生成连续视频文件方式
2019/12/25 Python
python读取与处理netcdf数据方式
2020/02/14 Python
python如何爬取网页中的文字
2020/07/28 Python
快速实现一个简单的canvas迷宫游戏的示例
2018/07/04 HTML / CSS
美国特价机票专家:Airfarewatchdog
2018/01/24 全球购物
爱尔兰电脑、家电和家具购物网站:Buy It Direct
2019/07/09 全球购物
品学兼优的大学生自我评价
2013/09/20 职场文书
2015年试用期工作总结
2014/12/12 职场文书
台风停课通知
2015/04/24 职场文书
导游词之天津盘山
2019/11/01 职场文书
《火纹风花雪月无双》预告“神秘雇佣兵” 紫发剑客
2022/04/13 其他游戏