基于ID3决策树算法的实现(Python版)


Posted in Python onMay 31, 2017

实例如下:

# -*- coding:utf-8 -*-

from numpy import *
import numpy as np
import pandas as pd
from math import log
import operator

#计算数据集的香农熵
def calcShannonEnt(dataSet):
  numEntries=len(dataSet)
  labelCounts={}
  #给所有可能分类创建字典
  for featVec in dataSet:
    currentLabel=featVec[-1]
    if currentLabel not in labelCounts.keys():
      labelCounts[currentLabel]=0
    labelCounts[currentLabel]+=1
  shannonEnt=0.0
  #以2为底数计算香农熵
  for key in labelCounts:
    prob = float(labelCounts[key])/numEntries
    shannonEnt-=prob*log(prob,2)
  return shannonEnt


#对离散变量划分数据集,取出该特征取值为value的所有样本
def splitDataSet(dataSet,axis,value):
  retDataSet=[]
  for featVec in dataSet:
    if featVec[axis]==value:
      reducedFeatVec=featVec[:axis]
      reducedFeatVec.extend(featVec[axis+1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet

#对连续变量划分数据集,direction规定划分的方向,
#决定是划分出小于value的数据样本还是大于value的数据样本集
def splitContinuousDataSet(dataSet,axis,value,direction):
  retDataSet=[]
  for featVec in dataSet:
    if direction==0:
      if featVec[axis]>value:
        reducedFeatVec=featVec[:axis]
        reducedFeatVec.extend(featVec[axis+1:])
        retDataSet.append(reducedFeatVec)
    else:
      if featVec[axis]<=value:
        reducedFeatVec=featVec[:axis]
        reducedFeatVec.extend(featVec[axis+1:])
        retDataSet.append(reducedFeatVec)
  return retDataSet

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet,labels):
  numFeatures=len(dataSet[0])-1
  baseEntropy=calcShannonEnt(dataSet)
  bestInfoGain=0.0
  bestFeature=-1
  bestSplitDict={}
  for i in range(numFeatures):
    featList=[example[i] for example in dataSet]
    #对连续型特征进行处理
    if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
      #产生n-1个候选划分点
      sortfeatList=sorted(featList)
      splitList=[]
      for j in range(len(sortfeatList)-1):
        splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)

      bestSplitEntropy=10000
      slen=len(splitList)
      #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
      for j in range(slen):
        value=splitList[j]
        newEntropy=0.0
        subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
        subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
        prob0=len(subDataSet0)/float(len(dataSet))
        newEntropy+=prob0*calcShannonEnt(subDataSet0)
        prob1=len(subDataSet1)/float(len(dataSet))
        newEntropy+=prob1*calcShannonEnt(subDataSet1)
        if newEntropy<bestSplitEntropy:
          bestSplitEntropy=newEntropy
          bestSplit=j
      #用字典记录当前特征的最佳划分点
      bestSplitDict[labels[i]]=splitList[bestSplit]
      infoGain=baseEntropy-bestSplitEntropy
    #对离散型特征进行处理
    else:
      uniqueVals=set(featList)
      newEntropy=0.0
      #计算该特征下每种划分的信息熵
      for value in uniqueVals:
        subDataSet=splitDataSet(dataSet,i,value)
        prob=len(subDataSet)/float(len(dataSet))
        newEntropy+=prob*calcShannonEnt(subDataSet)
      infoGain=baseEntropy-newEntropy
    if infoGain>bestInfoGain:
      bestInfoGain=infoGain
      bestFeature=i
  #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
  #即是否小于等于bestSplitValue
  if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':
    bestSplitValue=bestSplitDict[labels[bestFeature]]
    labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
    for i in range(shape(dataSet)[0]):
      if dataSet[i][bestFeature]<=bestSplitValue:
        dataSet[i][bestFeature]=1
      else:
        dataSet[i][bestFeature]=0
  return bestFeature

#特征若已经划分完,节点下的样本还没有统一取值,则需要进行投票
def majorityCnt(classList):
  classCount={}
  for vote in classList:
    if vote not in classCount.keys():
      classCount[vote]=0
    classCount[vote]+=1
  return max(classCount)

#主程序,递归产生决策树
def createTree(dataSet,labels,data_full,labels_full):
  classList=[example[-1] for example in dataSet]
  if classList.count(classList[0])==len(classList):
    return classList[0]
  if len(dataSet[0])==1:
    return majorityCnt(classList)
  bestFeat=chooseBestFeatureToSplit(dataSet,labels)
  bestFeatLabel=labels[bestFeat]
  myTree={bestFeatLabel:{}}
  featValues=[example[bestFeat] for example in dataSet]
  uniqueVals=set(featValues)
  if type(dataSet[0][bestFeat]).__name__=='str':
    currentlabel=labels_full.index(labels[bestFeat])
    featValuesFull=[example[currentlabel] for example in data_full]
    uniqueValsFull=set(featValuesFull)
  del(labels[bestFeat])
  #针对bestFeat的每个取值,划分出一个子树。
  for value in uniqueVals:
    subLabels=labels[:]
    if type(dataSet[0][bestFeat]).__name__=='str':
      uniqueValsFull.remove(value)
    myTree[bestFeatLabel][value]=createTree(splitDataSet\
     (dataSet,bestFeat,value),subLabels,data_full,labels_full)
  if type(dataSet[0][bestFeat]).__name__=='str':
    for value in uniqueValsFull:
      myTree[bestFeatLabel][value]=majorityCnt(classList)
  return myTree

import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")


#计算树的叶子节点数量
def getNumLeafs(myTree):
  numLeafs=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      numLeafs+=getNumLeafs(secondDict[key])
    else: numLeafs+=1
  return numLeafs

#计算树的最大深度
def getTreeDepth(myTree):
  maxDepth=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      thisDepth=1+getTreeDepth(secondDict[key])
    else: thisDepth=1
    if thisDepth>maxDepth:
      maxDepth=thisDepth
  return maxDepth

#画节点
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
  createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
  xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\
  bbox=nodeType,arrowprops=arrow_args)

#画箭头上的文字
def plotMidText(cntrPt,parentPt,txtString):
  lens=len(txtString)
  xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
  yMid=(parentPt[1]+cntrPt[1])/2.0
  createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
  numLeafs=getNumLeafs(myTree)
  depth=getTreeDepth(myTree)
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
  plotMidText(cntrPt,parentPt,nodeTxt)
  plotNode(firstStr,cntrPt,parentPt,decisionNode)
  secondDict=myTree[firstStr]
  plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      plotTree(secondDict[key],cntrPt,str(key))
    else:
      plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
      plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
      plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
  plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD

def createPlot(inTree):
  fig=plt.figure(1,facecolor='white')
  fig.clf()
  axprops=dict(xticks=[],yticks=[])
  createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
  plotTree.totalW=float(getNumLeafs(inTree))
  plotTree.totalD=float(getTreeDepth(inTree))
  plotTree.x0ff=-0.5/plotTree.totalW
  plotTree.y0ff=1.0
  plotTree(inTree,(0.5,1.0),'')
  plt.show()

df=pd.read_csv('watermelon_4_3.csv')
data=df.values[:,1:].tolist()
data_full=data[:]
labels=df.columns.values[1:-1].tolist()
labels_full=labels[:]
myTree=createTree(data,labels,data_full,labels_full)
print(myTree)
createPlot(myTree)

最终结果如下:

{'texture': {'blur': 0, 'little_blur': {'touch': {'soft_stick': 1, 'hard_smooth': 0}}, 'distinct': {'density<=0.38149999999999995': {0: 1, 1: 0}}}}

得到的决策树如下:

基于ID3决策树算法的实现(Python版)

参考资料:

《机器学习实战》

《机器学习》周志华著

以上这篇基于ID3决策树算法的实现(Python版)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作MySQL数据库的方法分享
May 29 Python
python 实现插入排序算法
Jun 05 Python
详解python3中socket套接字的编码问题解决
Jul 01 Python
Python简单实现自动删除目录下空文件夹的方法
Aug 29 Python
Python列表推导式、字典推导式与集合推导式用法实例分析
Feb 07 Python
基于Python3.6+splinter实现自动抢火车票
Sep 25 Python
pandas条件组合筛选和按范围筛选的示例代码
Aug 26 Python
Python3+Requests+Excel完整接口自动化测试框架的实现
Oct 11 Python
使用python实现画AR模型时序图
Nov 20 Python
Python -m参数原理及使用方法解析
Aug 21 Python
linux mint中搜狗输入法导致pycharm卡死的问题
Oct 28 Python
Python之Matplotlib绘制热力图和面积图
Apr 13 Python
Python基础知识_浅谈用户交互
May 31 #Python
python数据类型_字符串常用操作(详解)
May 30 #Python
python数据类型_元组、字典常用操作方法(介绍)
May 30 #Python
node.js获取参数的常用方法(总结)
May 29 #Python
老生常谈python函数参数的区别(必看篇)
May 29 #Python
Python进阶_关于命名空间与作用域(详解)
May 29 #Python
浅谈对yield的初步理解
May 29 #Python
You might like
一个odbc连mssql分页的类
2006/10/09 PHP
PHP 编写的 25个游戏脚本
2009/05/11 PHP
php利用递归实现删除文件目录的方法
2016/09/23 PHP
JavaScript事件列表解说
2006/12/22 Javascript
静态的动态续篇之来点XML
2006/12/23 Javascript
js tab 选项卡
2009/04/26 Javascript
jquery下onpropertychange事件的绑定方法
2010/08/01 Javascript
JS日期和时间选择控件升级版(自写)
2013/08/02 Javascript
javascript表单验证使用示例(javascript验证邮箱)
2014/01/07 Javascript
使用JS画图之点、线、面
2015/01/12 Javascript
JS使用ajax方法获取指定url的head信息中指定字段值的方法
2015/03/24 Javascript
如何使用jQuery技术开发ios风格的页面导航菜单
2015/07/29 Javascript
JS实现本地存储信息的方法(基于localStorage与userData)
2017/02/18 Javascript
nodejs中方法和模块用法示例
2018/12/24 NodeJs
微信小程序实现联动选择器
2019/02/15 Javascript
vue项目使用高德地图的定位及关键字搜索功能的实例代码(踩坑经验)
2020/03/07 Javascript
[01:37]DOTA2超级联赛专访ChuaN 传奇般的电竞之路
2013/06/19 DOTA
Python查找函数f(x)=0根的解决方法
2015/05/07 Python
python删除过期文件的方法
2015/05/29 Python
Python的Django中django-userena组件的简单使用教程
2015/05/30 Python
详解Python中最难理解的点-装饰器
2017/04/03 Python
5个很好的Python面试题问题答案及分析
2018/01/19 Python
Python实现的视频播放器功能完整示例
2018/02/01 Python
python+selenium实现自动抢票功能实例代码
2018/11/23 Python
pandas计算最大连续间隔的方法
2019/07/04 Python
Python PIL库图片灰化处理
2020/04/07 Python
基于Python编写一个计算器程序,实现简单的加减乘除和取余二元运算
2020/08/05 Python
Python 实现PS滤镜的旋涡特效
2020/12/03 Python
今天学到的CSS最新技术(与图片背景相关)
2012/12/24 HTML / CSS
美国运动鞋和运动服零售商:Footaction
2017/04/07 全球购物
小蚁科技官方商店:YI Technology
2019/08/23 全球购物
成人高等教育毕业生自我鉴定
2013/10/22 职场文书
乡镇领导干部个人对照检查材料思想汇报
2014/09/23 职场文书
开展党的群众路线教育实践活动剖析材料
2014/10/13 职场文书
上课迟到检讨书范文
2015/05/06 职场文书
深入理解margin塌陷和margin合并的解决方案
2021/06/26 HTML / CSS