基于ID3决策树算法的实现(Python版)


Posted in Python onMay 31, 2017

实例如下:

# -*- coding:utf-8 -*-

from numpy import *
import numpy as np
import pandas as pd
from math import log
import operator

#计算数据集的香农熵
def calcShannonEnt(dataSet):
  numEntries=len(dataSet)
  labelCounts={}
  #给所有可能分类创建字典
  for featVec in dataSet:
    currentLabel=featVec[-1]
    if currentLabel not in labelCounts.keys():
      labelCounts[currentLabel]=0
    labelCounts[currentLabel]+=1
  shannonEnt=0.0
  #以2为底数计算香农熵
  for key in labelCounts:
    prob = float(labelCounts[key])/numEntries
    shannonEnt-=prob*log(prob,2)
  return shannonEnt


#对离散变量划分数据集,取出该特征取值为value的所有样本
def splitDataSet(dataSet,axis,value):
  retDataSet=[]
  for featVec in dataSet:
    if featVec[axis]==value:
      reducedFeatVec=featVec[:axis]
      reducedFeatVec.extend(featVec[axis+1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet

#对连续变量划分数据集,direction规定划分的方向,
#决定是划分出小于value的数据样本还是大于value的数据样本集
def splitContinuousDataSet(dataSet,axis,value,direction):
  retDataSet=[]
  for featVec in dataSet:
    if direction==0:
      if featVec[axis]>value:
        reducedFeatVec=featVec[:axis]
        reducedFeatVec.extend(featVec[axis+1:])
        retDataSet.append(reducedFeatVec)
    else:
      if featVec[axis]<=value:
        reducedFeatVec=featVec[:axis]
        reducedFeatVec.extend(featVec[axis+1:])
        retDataSet.append(reducedFeatVec)
  return retDataSet

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet,labels):
  numFeatures=len(dataSet[0])-1
  baseEntropy=calcShannonEnt(dataSet)
  bestInfoGain=0.0
  bestFeature=-1
  bestSplitDict={}
  for i in range(numFeatures):
    featList=[example[i] for example in dataSet]
    #对连续型特征进行处理
    if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
      #产生n-1个候选划分点
      sortfeatList=sorted(featList)
      splitList=[]
      for j in range(len(sortfeatList)-1):
        splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)

      bestSplitEntropy=10000
      slen=len(splitList)
      #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
      for j in range(slen):
        value=splitList[j]
        newEntropy=0.0
        subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
        subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
        prob0=len(subDataSet0)/float(len(dataSet))
        newEntropy+=prob0*calcShannonEnt(subDataSet0)
        prob1=len(subDataSet1)/float(len(dataSet))
        newEntropy+=prob1*calcShannonEnt(subDataSet1)
        if newEntropy<bestSplitEntropy:
          bestSplitEntropy=newEntropy
          bestSplit=j
      #用字典记录当前特征的最佳划分点
      bestSplitDict[labels[i]]=splitList[bestSplit]
      infoGain=baseEntropy-bestSplitEntropy
    #对离散型特征进行处理
    else:
      uniqueVals=set(featList)
      newEntropy=0.0
      #计算该特征下每种划分的信息熵
      for value in uniqueVals:
        subDataSet=splitDataSet(dataSet,i,value)
        prob=len(subDataSet)/float(len(dataSet))
        newEntropy+=prob*calcShannonEnt(subDataSet)
      infoGain=baseEntropy-newEntropy
    if infoGain>bestInfoGain:
      bestInfoGain=infoGain
      bestFeature=i
  #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
  #即是否小于等于bestSplitValue
  if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':
    bestSplitValue=bestSplitDict[labels[bestFeature]]
    labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
    for i in range(shape(dataSet)[0]):
      if dataSet[i][bestFeature]<=bestSplitValue:
        dataSet[i][bestFeature]=1
      else:
        dataSet[i][bestFeature]=0
  return bestFeature

#特征若已经划分完,节点下的样本还没有统一取值,则需要进行投票
def majorityCnt(classList):
  classCount={}
  for vote in classList:
    if vote not in classCount.keys():
      classCount[vote]=0
    classCount[vote]+=1
  return max(classCount)

#主程序,递归产生决策树
def createTree(dataSet,labels,data_full,labels_full):
  classList=[example[-1] for example in dataSet]
  if classList.count(classList[0])==len(classList):
    return classList[0]
  if len(dataSet[0])==1:
    return majorityCnt(classList)
  bestFeat=chooseBestFeatureToSplit(dataSet,labels)
  bestFeatLabel=labels[bestFeat]
  myTree={bestFeatLabel:{}}
  featValues=[example[bestFeat] for example in dataSet]
  uniqueVals=set(featValues)
  if type(dataSet[0][bestFeat]).__name__=='str':
    currentlabel=labels_full.index(labels[bestFeat])
    featValuesFull=[example[currentlabel] for example in data_full]
    uniqueValsFull=set(featValuesFull)
  del(labels[bestFeat])
  #针对bestFeat的每个取值,划分出一个子树。
  for value in uniqueVals:
    subLabels=labels[:]
    if type(dataSet[0][bestFeat]).__name__=='str':
      uniqueValsFull.remove(value)
    myTree[bestFeatLabel][value]=createTree(splitDataSet\
     (dataSet,bestFeat,value),subLabels,data_full,labels_full)
  if type(dataSet[0][bestFeat]).__name__=='str':
    for value in uniqueValsFull:
      myTree[bestFeatLabel][value]=majorityCnt(classList)
  return myTree

import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")


#计算树的叶子节点数量
def getNumLeafs(myTree):
  numLeafs=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      numLeafs+=getNumLeafs(secondDict[key])
    else: numLeafs+=1
  return numLeafs

#计算树的最大深度
def getTreeDepth(myTree):
  maxDepth=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      thisDepth=1+getTreeDepth(secondDict[key])
    else: thisDepth=1
    if thisDepth>maxDepth:
      maxDepth=thisDepth
  return maxDepth

#画节点
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
  createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
  xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\
  bbox=nodeType,arrowprops=arrow_args)

#画箭头上的文字
def plotMidText(cntrPt,parentPt,txtString):
  lens=len(txtString)
  xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
  yMid=(parentPt[1]+cntrPt[1])/2.0
  createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
  numLeafs=getNumLeafs(myTree)
  depth=getTreeDepth(myTree)
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
  plotMidText(cntrPt,parentPt,nodeTxt)
  plotNode(firstStr,cntrPt,parentPt,decisionNode)
  secondDict=myTree[firstStr]
  plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      plotTree(secondDict[key],cntrPt,str(key))
    else:
      plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
      plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
      plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
  plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD

def createPlot(inTree):
  fig=plt.figure(1,facecolor='white')
  fig.clf()
  axprops=dict(xticks=[],yticks=[])
  createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
  plotTree.totalW=float(getNumLeafs(inTree))
  plotTree.totalD=float(getTreeDepth(inTree))
  plotTree.x0ff=-0.5/plotTree.totalW
  plotTree.y0ff=1.0
  plotTree(inTree,(0.5,1.0),'')
  plt.show()

df=pd.read_csv('watermelon_4_3.csv')
data=df.values[:,1:].tolist()
data_full=data[:]
labels=df.columns.values[1:-1].tolist()
labels_full=labels[:]
myTree=createTree(data,labels,data_full,labels_full)
print(myTree)
createPlot(myTree)

最终结果如下:

{'texture': {'blur': 0, 'little_blur': {'touch': {'soft_stick': 1, 'hard_smooth': 0}}, 'distinct': {'density<=0.38149999999999995': {0: 1, 1: 0}}}}

得到的决策树如下:

基于ID3决策树算法的实现(Python版)

参考资料:

《机器学习实战》

《机器学习》周志华著

以上这篇基于ID3决策树算法的实现(Python版)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用pil进行图像处理(等比例压缩、裁剪)实例代码
Dec 11 Python
Python基于辗转相除法求解最大公约数的方法示例
Apr 04 Python
Tensorflow卷积神经网络实例进阶
May 24 Python
Python3之简单搭建自带服务器的实例讲解
Jun 04 Python
Django 中使用流响应处理视频的方法
Jul 20 Python
Python3按一定数据位数格式处理bin文件的方法
Jan 24 Python
详解python 爬取12306验证码
May 10 Python
Pandas_cum累积计算和rolling滚动计算的用法详解
Jul 04 Python
python函数的万能参数传参详解
Jul 26 Python
PIL.Image.open和cv2.imread的比较与相互转换的方法
Jun 03 Python
Python爬虫后获取重定向url的两种方法
Jan 19 Python
python之基数排序的实现
Jul 26 Python
Python基础知识_浅谈用户交互
May 31 #Python
python数据类型_字符串常用操作(详解)
May 30 #Python
python数据类型_元组、字典常用操作方法(介绍)
May 30 #Python
node.js获取参数的常用方法(总结)
May 29 #Python
老生常谈python函数参数的区别(必看篇)
May 29 #Python
Python进阶_关于命名空间与作用域(详解)
May 29 #Python
浅谈对yield的初步理解
May 29 #Python
You might like
destoon调用自定义模板及样式的公告栏
2014/06/21 PHP
php中ob_flush函数和flush函数用法分析
2015/03/18 PHP
PHP处理会话函数大总结
2015/08/05 PHP
用javascript来实现动画导航效果的代码
2007/12/16 Javascript
JavaScript DOM 学习第九章 选取范围的介绍
2010/02/19 Javascript
JavaScript子窗口ModalDialog中操作父窗口对像
2012/12/11 Javascript
再探JavaScript作用域
2014/09/24 Javascript
javascript实现的固定位置悬浮窗口实例
2015/04/30 Javascript
jquery使用经验小结
2015/05/20 Javascript
Nodejs实战心得之eventproxy模块控制并发
2015/10/27 NodeJs
JS实现网页上随滚动条滚动的层效果代码
2015/11/04 Javascript
学习Bootstrap滚动监听 附调用方法
2016/07/02 Javascript
Angularjs结合Bootstrap制作的一个TODO List
2016/08/18 Javascript
Ajax异步获取html数据中包含js方法无效的解决方法
2017/02/20 Javascript
nodejs个人博客开发第三步 载入页面
2017/04/12 NodeJs
JS如何设置元素样式的方法示例
2017/08/28 Javascript
详解webpack + vue + node 打造单页面(入门篇)
2017/09/23 Javascript
《javascript少儿编程》location术语总结
2018/05/27 Javascript
vue cli 3.0 使用全过程解析
2018/06/14 Javascript
微信小程序点击列表跳转到对应详情页过程解析
2019/09/26 Javascript
VUE页面中通过双击实现复制表格中内容的示例代码
2020/06/11 Javascript
Django框架下在URLconf中指定视图缓存的方法
2015/07/23 Python
Python使用matplotlib模块绘制图像并设置标题与坐标轴等信息示例
2018/05/04 Python
python 字典操作提取key,value的方法
2019/06/26 Python
python中查看.db文件中表格的名字及表格中的字段操作
2020/07/07 Python
HTML5 新旧语法标记对我们有什么好处
2012/12/13 HTML / CSS
英国经济型酒店品牌:Travelodge
2019/12/17 全球购物
电子商务专业个人的自我评价
2013/11/19 职场文书
学生会主席就职演讲稿
2014/01/14 职场文书
市政施工员自我鉴定
2014/01/15 职场文书
啤酒节策划方案
2014/05/28 职场文书
董事长助理工作职责范本
2014/07/01 职场文书
暑假开始了,你的暑假学习计划写好了吗?
2019/07/04 职场文书
想创业成功,需要掌握这些要点
2019/12/06 职场文书
读《茶花女》有感:山茶花的盛开与凋零
2020/01/17 职场文书
JS实现九宫格拼图游戏
2022/06/28 Javascript