解读python如何实现决策树算法


Posted in Python onOctober 11, 2018

数据描述

每条数据项储存在列表中,最后一列储存结果

多条数据项形成数据集

data=[[d1,d2,d3...dn,result],
   [d1,d2,d3...dn,result],
        .
        .
   [d1,d2,d3...dn,result]]

决策树数据结构

class DecisionNode:
  '''决策树节点
  '''
   
  def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
    '''初始化决策树节点
     
    args:    
    col -- 按数据集的col列划分数据集
    value -- 以value作为划分col列的参照
    result -- 只有叶子节点有,代表最终划分出的子数据集结果统计信息。{‘结果':结果出现次数}
    rb,fb -- 代表左右子树
    '''
    self.col=col
    self.value=value
    self.results=results
    self.tb=tb
    self.fb=fb

决策树分类的最终结果是将数据项划分出了若干子集,其中每个子集的结果都一样,所以这里采用{‘结果':结果出现次数}的方式表达每个子集

def pideset(rows,column,value):
  '''依据数据集rows的column列的值,判断其与参考值value的关系对数据集进行拆分
    返回两个数据集
  '''
  split_function=None
  #value是数值类型
  if isinstance(value,int) or isinstance(value,float):
    #定义lambda函数当row[column]>=value时返回true
    split_function=lambda row:row[column]>=value
  #value是字符类型
  else:
    #定义lambda函数当row[column]==value时返回true
    split_function=lambda row:row[column]==value
  #将数据集拆分成两个
  set1=[row for row in rows if split_function(row)]
  set2=[row for row in rows if not split_function(row)]
  #返回两个数据集
  return (set1,set2)
 
def uniquecounts(rows):
  '''计算数据集rows中有几种最终结果,计算结果出现次数,返回一个字典
  '''
  results={}
  for row in rows:
    r=row[len(row)-1]
    if r not in results: results[r]=0
    results[r]+=1
  return results
 
def giniimpurity(rows):
  '''返回rows数据集的基尼不纯度
  '''
  total=len(rows)
  counts=uniquecounts(rows)
  imp=0
  for k1 in counts:
    p1=float(counts[k1])/total
    for k2 in counts:
      if k1==k2: continue
      p2=float(counts[k2])/total
      imp+=p1*p2
  return imp
 
def entropy(rows):
  '''返回rows数据集的熵
  '''
  from math import log
  log2=lambda x:log(x)/log(2) 
  results=uniquecounts(rows)
  ent=0.0
  for r in results.keys():
    p=float(results[r])/len(rows)
    ent=ent-p*log2(p)
  return ent
 
def build_tree(rows,scoref=entropy):
  '''构造决策树
  '''
  if len(rows)==0: return DecisionNode()
  current_score=scoref(rows)
 
  # 最佳信息增益
  best_gain=0.0
  #
  best_criteria=None
  #最佳划分
  best_sets=None
 
  column_count=len(rows[0])-1
  #遍历数据集的列,确定分割顺序
  for col in range(0,column_count):
    column_values={}
    # 构造字典
    for row in rows:
      column_values[row[col]]=1
    for value in column_values.keys():
      (set1,set2)=pideset(rows,col,value)
      p=float(len(set1))/len(rows)
      # 计算信息增益
      gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
      if gain>best_gain and len(set1)>0 and len(set2)>0:
        best_gain=gain
        best_criteria=(col,value)
        best_sets=(set1,set2)
  # 如果划分的两个数据集熵小于原数据集,进一步划分它们
  if best_gain>0:
    trueBranch=build_tree(best_sets[0])
    falseBranch=build_tree(best_sets[1])
    return DecisionNode(col=best_criteria[0],value=best_criteria[1],
            tb=trueBranch,fb=falseBranch)
  # 如果划分的两个数据集熵不小于原数据集,停止划分
  else:
    return DecisionNode(results=uniquecounts(rows))
 
def print_tree(tree,indent=''):
  if tree.results!=None:
    print(str(tree.results))
  else:
    print(str(tree.col)+':'+str(tree.value)+'? ')
    print(indent+'T->',end='')
    print_tree(tree.tb,indent+' ')
    print(indent+'F->',end='')
    print_tree(tree.fb,indent+' ')
 
 
def getwidth(tree):
  if tree.tb==None and tree.fb==None: return 1
  return getwidth(tree.tb)+getwidth(tree.fb)
 
def getdepth(tree):
  if tree.tb==None and tree.fb==None: return 0
  return max(getdepth(tree.tb),getdepth(tree.fb))+1
 
 
def drawtree(tree,jpeg='tree.jpg'):
  w=getwidth(tree)*100
  h=getdepth(tree)*100+120
 
  img=Image.new('RGB',(w,h),(255,255,255))
  draw=ImageDraw.Draw(img)
 
  drawnode(draw,tree,w/2,20)
  img.save(jpeg,'JPEG')
 
def drawnode(draw,tree,x,y):
  if tree.results==None:
    # Get the width of each branch
    w1=getwidth(tree.fb)*100
    w2=getwidth(tree.tb)*100
 
    # Determine the total space required by this node
    left=x-(w1+w2)/2
    right=x+(w1+w2)/2
 
    # Draw the condition string
    draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))
 
    # Draw links to the branches
    draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
    draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))
   
    # Draw the branch nodes
    drawnode(draw,tree.fb,left+w1/2,y+100)
    drawnode(draw,tree.tb,right-w2/2,y+100)
  else:
    txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
    draw.text((x-20,y),txt,(0,0,0))

对测试数据进行分类(附带处理缺失数据)

def mdclassify(observation,tree):
  '''对缺失数据进行分类
   
  args:
  observation -- 发生信息缺失的数据项
  tree -- 训练完成的决策树
   
  返回代表该分类的结果字典
  '''
 
  # 判断数据是否到达叶节点
  if tree.results!=None:
    # 已经到达叶节点,返回结果result
    return tree.results
  else:
    # 对数据项的col列进行分析
    v=observation[tree.col]
 
    # 若col列数据缺失
    if v==None:
      #对tree的左右子树分别使用mdclassify,tr是左子树得到的结果字典,fr是右子树得到的结果字典
      tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb)
 
      # 分别以结果占总数比例计算得到左右子树的权重
      tcount=sum(tr.values())
      fcount=sum(fr.values())
      tw=float(tcount)/(tcount+fcount)
      fw=float(fcount)/(tcount+fcount)
      result={}
 
      # 计算左右子树的加权平均
      for k,v in tr.items(): 
        result[k]=v*tw
      for k,v in fr.items(): 
        # fr的结果k有可能并不在tr中,在result中初始化k
        if k not in result: 
          result[k]=0 
        # fr的结果累加到result中 
        result[k]+=v*fw
      return result
 
    # col列没有缺失,继续沿决策树分类
    else:
      if isinstance(v,int) or isinstance(v,float):
        if v>=tree.value: branch=tree.tb
        else: branch=tree.fb
      else:
        if v==tree.value: branch=tree.tb
        else: branch=tree.fb
      return mdclassify(observation,branch)
 
tree=build_tree(my_data)
print(mdclassify(['google',None,'yes',None],tree))
print(mdclassify(['google','France',None,None],tree))

决策树剪枝

def prune(tree,mingain):
  '''对决策树进行剪枝
   
  args:
  tree -- 决策树
  mingain -- 最小信息增益
   
  返回
  '''
  # 修剪非叶节点
  if tree.tb.results==None:
    prune(tree.tb,mingain)
  if tree.fb.results==None:
    prune(tree.fb,mingain)
  #合并两个叶子节点
  if tree.tb.results!=None and tree.fb.results!=None:
    tb,fb=[],[]
    for v,c in tree.tb.results.items():
      tb+=[[v]]*c
    for v,c in tree.fb.results.items():
      fb+=[[v]]*c
    #计算熵减少情况
    delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)
    #熵的增加量小于mingain,可以合并分支
    if delta<mingain:
      tree.tb,tree.fb=None,None
      tree.results=uniquecounts(tb+fb)
Python 相关文章推荐
简单了解Python下用于监视文件系统的pyinotify包
Nov 13 Python
利用Python读取文件的四种不同方法比对
May 18 Python
Python实现在线暴力破解邮箱账号密码功能示例【测试可用】
Sep 06 Python
Python之Scrapy爬虫框架安装及简单使用详解
Dec 22 Python
Python并行分布式框架Celery详解
Oct 15 Python
Python udp网络程序实现发送、接收数据功能示例
Dec 09 Python
PIL包中Image模块的convert()函数的具体使用
Feb 26 Python
Python如何使用队列方式实现多线程爬虫
May 12 Python
如何利用python进行时间序列分析
Aug 04 Python
PyTorch如何搭建一个简单的网络
Aug 24 Python
class类在python中获取金融数据的实例方法
Dec 10 Python
使用pandas生成/读取csv文件的方法实例
Jul 09 Python
Python tkinter的grid布局及Text动态显示方法
Oct 11 #Python
对python requests的content和text方法的区别详解
Oct 11 #Python
使用pip发布Python程序的方法步骤
Oct 11 #Python
对python Tkinter Text的用法详解
Oct 11 #Python
python数据批量写入ScrolledText的优化方法
Oct 11 #Python
攻击者是如何将PHP Phar包伪装成图像以绕过文件类型检测的(推荐)
Oct 11 #Python
python中join()方法介绍
Oct 11 #Python
You might like
用IE远程创建Mysql数据库的简易程序
2006/10/09 PHP
MySQL中create table语句的基本语法是
2007/01/15 PHP
php数组中删除元素的实现代码
2012/06/22 PHP
PHP防止跨域提交表单
2013/11/01 PHP
详解php用curl调用接口方法,get和post两种方式
2017/01/13 PHP
PHP正则删除HTML代码中宽高样式的方法
2017/06/12 PHP
Thinkphp5框架中引入Markdown编辑器操作示例
2020/06/03 PHP
jQuery 表格工具集
2010/04/25 Javascript
jQuery 获取浏览器所在的IP地址的小例子
2013/11/08 Javascript
jQuery的context属性用法实例
2014/12/27 Javascript
js实现透明度渐变效果的方法
2015/04/10 Javascript
vue超时计算的组件实例代码
2018/07/09 Javascript
JS动画实现回调地狱promise的实例代码详解
2018/11/08 Javascript
微信小程序学习笔记之跳转页面、传递参数获得数据操作图文详解
2019/03/28 Javascript
Vue组件通信的几种实现方法
2019/04/25 Javascript
vue以组件或者插件的形式实现throttle或者debounce
2019/05/22 Javascript
[04:38]完美世界携手游戏风云打造 卡尔工作室饰品系统篇
2013/04/25 DOTA
[06:42]DOTA2每周TOP10 精彩击杀集锦vol.1
2014/06/25 DOTA
python聊天程序实例代码分享
2013/11/18 Python
通过数据库对Django进行删除字段和删除模型的操作
2015/07/21 Python
对python中for、if、while的区别与比较方法
2018/06/25 Python
python解析json串与正则匹配对比方法
2018/12/20 Python
python 中如何获取列表的索引
2019/07/02 Python
python二维键值数组生成转json的例子
2019/12/06 Python
python学习将数据写入文件并保存方法
2020/06/07 Python
使用phonegap操作数据库的实现方法
2017/03/31 HTML / CSS
美国购买和销售礼品卡平台:Raise
2017/01/13 全球购物
大学生个人自我鉴定
2013/12/03 职场文书
婚礼主持词
2014/03/13 职场文书
食品采购员岗位职责
2014/04/14 职场文书
个人典型事迹材料
2014/12/30 职场文书
装饰技术负责人岗位职责
2015/04/13 职场文书
结婚通知短信怎么写
2015/04/17 职场文书
招商银行工作证明
2015/06/17 职场文书
简单聊聊Golang中defer预计算参数
2022/03/25 Golang
angular4实现带搜索的下拉框
2022/03/25 Javascript