基于ID3决策树算法的实现(Python版)


Posted in Python onMay 31, 2017

实例如下:

# -*- coding:utf-8 -*-

from numpy import *
import numpy as np
import pandas as pd
from math import log
import operator

#计算数据集的香农熵
def calcShannonEnt(dataSet):
  numEntries=len(dataSet)
  labelCounts={}
  #给所有可能分类创建字典
  for featVec in dataSet:
    currentLabel=featVec[-1]
    if currentLabel not in labelCounts.keys():
      labelCounts[currentLabel]=0
    labelCounts[currentLabel]+=1
  shannonEnt=0.0
  #以2为底数计算香农熵
  for key in labelCounts:
    prob = float(labelCounts[key])/numEntries
    shannonEnt-=prob*log(prob,2)
  return shannonEnt


#对离散变量划分数据集,取出该特征取值为value的所有样本
def splitDataSet(dataSet,axis,value):
  retDataSet=[]
  for featVec in dataSet:
    if featVec[axis]==value:
      reducedFeatVec=featVec[:axis]
      reducedFeatVec.extend(featVec[axis+1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet

#对连续变量划分数据集,direction规定划分的方向,
#决定是划分出小于value的数据样本还是大于value的数据样本集
def splitContinuousDataSet(dataSet,axis,value,direction):
  retDataSet=[]
  for featVec in dataSet:
    if direction==0:
      if featVec[axis]>value:
        reducedFeatVec=featVec[:axis]
        reducedFeatVec.extend(featVec[axis+1:])
        retDataSet.append(reducedFeatVec)
    else:
      if featVec[axis]<=value:
        reducedFeatVec=featVec[:axis]
        reducedFeatVec.extend(featVec[axis+1:])
        retDataSet.append(reducedFeatVec)
  return retDataSet

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet,labels):
  numFeatures=len(dataSet[0])-1
  baseEntropy=calcShannonEnt(dataSet)
  bestInfoGain=0.0
  bestFeature=-1
  bestSplitDict={}
  for i in range(numFeatures):
    featList=[example[i] for example in dataSet]
    #对连续型特征进行处理
    if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
      #产生n-1个候选划分点
      sortfeatList=sorted(featList)
      splitList=[]
      for j in range(len(sortfeatList)-1):
        splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)

      bestSplitEntropy=10000
      slen=len(splitList)
      #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
      for j in range(slen):
        value=splitList[j]
        newEntropy=0.0
        subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
        subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
        prob0=len(subDataSet0)/float(len(dataSet))
        newEntropy+=prob0*calcShannonEnt(subDataSet0)
        prob1=len(subDataSet1)/float(len(dataSet))
        newEntropy+=prob1*calcShannonEnt(subDataSet1)
        if newEntropy<bestSplitEntropy:
          bestSplitEntropy=newEntropy
          bestSplit=j
      #用字典记录当前特征的最佳划分点
      bestSplitDict[labels[i]]=splitList[bestSplit]
      infoGain=baseEntropy-bestSplitEntropy
    #对离散型特征进行处理
    else:
      uniqueVals=set(featList)
      newEntropy=0.0
      #计算该特征下每种划分的信息熵
      for value in uniqueVals:
        subDataSet=splitDataSet(dataSet,i,value)
        prob=len(subDataSet)/float(len(dataSet))
        newEntropy+=prob*calcShannonEnt(subDataSet)
      infoGain=baseEntropy-newEntropy
    if infoGain>bestInfoGain:
      bestInfoGain=infoGain
      bestFeature=i
  #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
  #即是否小于等于bestSplitValue
  if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':
    bestSplitValue=bestSplitDict[labels[bestFeature]]
    labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
    for i in range(shape(dataSet)[0]):
      if dataSet[i][bestFeature]<=bestSplitValue:
        dataSet[i][bestFeature]=1
      else:
        dataSet[i][bestFeature]=0
  return bestFeature

#特征若已经划分完,节点下的样本还没有统一取值,则需要进行投票
def majorityCnt(classList):
  classCount={}
  for vote in classList:
    if vote not in classCount.keys():
      classCount[vote]=0
    classCount[vote]+=1
  return max(classCount)

#主程序,递归产生决策树
def createTree(dataSet,labels,data_full,labels_full):
  classList=[example[-1] for example in dataSet]
  if classList.count(classList[0])==len(classList):
    return classList[0]
  if len(dataSet[0])==1:
    return majorityCnt(classList)
  bestFeat=chooseBestFeatureToSplit(dataSet,labels)
  bestFeatLabel=labels[bestFeat]
  myTree={bestFeatLabel:{}}
  featValues=[example[bestFeat] for example in dataSet]
  uniqueVals=set(featValues)
  if type(dataSet[0][bestFeat]).__name__=='str':
    currentlabel=labels_full.index(labels[bestFeat])
    featValuesFull=[example[currentlabel] for example in data_full]
    uniqueValsFull=set(featValuesFull)
  del(labels[bestFeat])
  #针对bestFeat的每个取值,划分出一个子树。
  for value in uniqueVals:
    subLabels=labels[:]
    if type(dataSet[0][bestFeat]).__name__=='str':
      uniqueValsFull.remove(value)
    myTree[bestFeatLabel][value]=createTree(splitDataSet\
     (dataSet,bestFeat,value),subLabels,data_full,labels_full)
  if type(dataSet[0][bestFeat]).__name__=='str':
    for value in uniqueValsFull:
      myTree[bestFeatLabel][value]=majorityCnt(classList)
  return myTree

import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")


#计算树的叶子节点数量
def getNumLeafs(myTree):
  numLeafs=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      numLeafs+=getNumLeafs(secondDict[key])
    else: numLeafs+=1
  return numLeafs

#计算树的最大深度
def getTreeDepth(myTree):
  maxDepth=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      thisDepth=1+getTreeDepth(secondDict[key])
    else: thisDepth=1
    if thisDepth>maxDepth:
      maxDepth=thisDepth
  return maxDepth

#画节点
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
  createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
  xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\
  bbox=nodeType,arrowprops=arrow_args)

#画箭头上的文字
def plotMidText(cntrPt,parentPt,txtString):
  lens=len(txtString)
  xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
  yMid=(parentPt[1]+cntrPt[1])/2.0
  createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
  numLeafs=getNumLeafs(myTree)
  depth=getTreeDepth(myTree)
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
  plotMidText(cntrPt,parentPt,nodeTxt)
  plotNode(firstStr,cntrPt,parentPt,decisionNode)
  secondDict=myTree[firstStr]
  plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      plotTree(secondDict[key],cntrPt,str(key))
    else:
      plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
      plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
      plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
  plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD

def createPlot(inTree):
  fig=plt.figure(1,facecolor='white')
  fig.clf()
  axprops=dict(xticks=[],yticks=[])
  createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
  plotTree.totalW=float(getNumLeafs(inTree))
  plotTree.totalD=float(getTreeDepth(inTree))
  plotTree.x0ff=-0.5/plotTree.totalW
  plotTree.y0ff=1.0
  plotTree(inTree,(0.5,1.0),'')
  plt.show()

df=pd.read_csv('watermelon_4_3.csv')
data=df.values[:,1:].tolist()
data_full=data[:]
labels=df.columns.values[1:-1].tolist()
labels_full=labels[:]
myTree=createTree(data,labels,data_full,labels_full)
print(myTree)
createPlot(myTree)

最终结果如下:

{'texture': {'blur': 0, 'little_blur': {'touch': {'soft_stick': 1, 'hard_smooth': 0}}, 'distinct': {'density<=0.38149999999999995': {0: 1, 1: 0}}}}

得到的决策树如下:

基于ID3决策树算法的实现(Python版)

参考资料:

《机器学习实战》

《机器学习》周志华著

以上这篇基于ID3决策树算法的实现(Python版)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的一个找零钱的小程序代码分享
Aug 25 Python
简单分析Python中用fork()函数生成的子进程
May 04 Python
python定时执行指定函数的方法
May 27 Python
python制作爬虫并将抓取结果保存到excel中
Apr 06 Python
Django ORM 聚合查询和分组查询实现详解
Aug 09 Python
python中的反斜杠问题深入讲解
Aug 12 Python
Python手绘可视化工具cutecharts使用实例
Dec 05 Python
python中property和setter装饰器用法
Dec 19 Python
python进行参数传递的方法
May 12 Python
keras做CNN的训练误差loss的下降操作
Jun 22 Python
Python代码覆盖率统计工具coverage.py用法详解
Nov 25 Python
Python的scikit-image模块实例讲解
Dec 30 Python
Python基础知识_浅谈用户交互
May 31 #Python
python数据类型_字符串常用操作(详解)
May 30 #Python
python数据类型_元组、字典常用操作方法(介绍)
May 30 #Python
node.js获取参数的常用方法(总结)
May 29 #Python
老生常谈python函数参数的区别(必看篇)
May 29 #Python
Python进阶_关于命名空间与作用域(详解)
May 29 #Python
浅谈对yield的初步理解
May 29 #Python
You might like
随机广告显示(PHP函数)
2006/10/09 PHP
CentOS安装php v8js教程
2015/02/26 PHP
PHP获取网站中各文章的第一张图片的代码示例
2016/05/20 PHP
PHP递归的三种常用方式
2019/02/28 PHP
window.navigate 与 window.location.href 的使用区别介绍
2013/09/21 Javascript
jQuery获取和设置表单元素的方法
2014/02/14 Javascript
js之ActiveX控件使用说明 new ActiveXObject()
2014/03/03 Javascript
JavaScript中扩展Array contains方法实例
2020/08/23 Javascript
分享网页检测摇一摇实例代码
2016/01/14 Javascript
jQuery动态加载css文件实现方法
2016/06/15 Javascript
jQuery实现的省市联动菜单功能示例【测试可用】
2017/01/13 Javascript
Angular2库初探
2017/03/01 Javascript
JS实现的Unicode编码转换操作示例
2017/04/28 Javascript
详解Angular-Cli中引用第三方库
2017/05/21 Javascript
微信小程序 新建登录页并实现tabBar隐藏
2017/06/13 Javascript
jQueryeasyui 中如何使用datetimebox 取两个日期间相隔的天数
2017/06/13 jQuery
微信小程序picker组件下拉框选择input输入框的实例
2017/09/20 Javascript
JS实现验证码倒计时的注册页面
2018/01/02 Javascript
webpack3里使用uglifyjs压缩js时打包报错的解决
2018/12/13 Javascript
解决vue跨域axios异步通信问题
2019/04/17 Javascript
JS实现省市县三级下拉联动
2020/04/10 Javascript
es6数组includes()用法实例分析
2020/04/18 Javascript
基于Ionic3实现选项卡切换并重新加载echarts
2020/09/24 Javascript
[01:22:28]DOTA2-DPC中国联赛 正赛 SAG vs RNG BO3 第一场 1月18日
2021/03/11 DOTA
ubuntu16.04制作vim和python3的开发环境
2018/09/23 Python
pycharm运行出现ImportError:No module named的解决方法
2018/10/13 Python
Python里字典的基本用法(包括嵌套字典)
2019/02/27 Python
Python过滤序列元素的方法
2020/07/31 Python
jupyter使用自动补全和切换默认浏览器的方法
2020/11/18 Python
利用纯html5绘制出来的一款非常漂亮的时钟
2015/01/04 HTML / CSS
添柏岚英国官方网站:Timberland英国
2019/11/28 全球购物
高速铁道技术专业求职信
2014/08/09 职场文书
工作年限证明模板
2015/06/15 职场文书
婚礼父母致辞
2015/07/28 职场文书
党员电教片《信仰》心得体会
2016/01/15 职场文书
2017年寒假社区服务活动总结
2016/04/06 职场文书