基于ID3决策树算法的实现(Python版)


Posted in Python onMay 31, 2017

实例如下:

# -*- coding:utf-8 -*-

from numpy import *
import numpy as np
import pandas as pd
from math import log
import operator

#计算数据集的香农熵
def calcShannonEnt(dataSet):
  numEntries=len(dataSet)
  labelCounts={}
  #给所有可能分类创建字典
  for featVec in dataSet:
    currentLabel=featVec[-1]
    if currentLabel not in labelCounts.keys():
      labelCounts[currentLabel]=0
    labelCounts[currentLabel]+=1
  shannonEnt=0.0
  #以2为底数计算香农熵
  for key in labelCounts:
    prob = float(labelCounts[key])/numEntries
    shannonEnt-=prob*log(prob,2)
  return shannonEnt


#对离散变量划分数据集,取出该特征取值为value的所有样本
def splitDataSet(dataSet,axis,value):
  retDataSet=[]
  for featVec in dataSet:
    if featVec[axis]==value:
      reducedFeatVec=featVec[:axis]
      reducedFeatVec.extend(featVec[axis+1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet

#对连续变量划分数据集,direction规定划分的方向,
#决定是划分出小于value的数据样本还是大于value的数据样本集
def splitContinuousDataSet(dataSet,axis,value,direction):
  retDataSet=[]
  for featVec in dataSet:
    if direction==0:
      if featVec[axis]>value:
        reducedFeatVec=featVec[:axis]
        reducedFeatVec.extend(featVec[axis+1:])
        retDataSet.append(reducedFeatVec)
    else:
      if featVec[axis]<=value:
        reducedFeatVec=featVec[:axis]
        reducedFeatVec.extend(featVec[axis+1:])
        retDataSet.append(reducedFeatVec)
  return retDataSet

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet,labels):
  numFeatures=len(dataSet[0])-1
  baseEntropy=calcShannonEnt(dataSet)
  bestInfoGain=0.0
  bestFeature=-1
  bestSplitDict={}
  for i in range(numFeatures):
    featList=[example[i] for example in dataSet]
    #对连续型特征进行处理
    if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
      #产生n-1个候选划分点
      sortfeatList=sorted(featList)
      splitList=[]
      for j in range(len(sortfeatList)-1):
        splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)

      bestSplitEntropy=10000
      slen=len(splitList)
      #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
      for j in range(slen):
        value=splitList[j]
        newEntropy=0.0
        subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
        subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
        prob0=len(subDataSet0)/float(len(dataSet))
        newEntropy+=prob0*calcShannonEnt(subDataSet0)
        prob1=len(subDataSet1)/float(len(dataSet))
        newEntropy+=prob1*calcShannonEnt(subDataSet1)
        if newEntropy<bestSplitEntropy:
          bestSplitEntropy=newEntropy
          bestSplit=j
      #用字典记录当前特征的最佳划分点
      bestSplitDict[labels[i]]=splitList[bestSplit]
      infoGain=baseEntropy-bestSplitEntropy
    #对离散型特征进行处理
    else:
      uniqueVals=set(featList)
      newEntropy=0.0
      #计算该特征下每种划分的信息熵
      for value in uniqueVals:
        subDataSet=splitDataSet(dataSet,i,value)
        prob=len(subDataSet)/float(len(dataSet))
        newEntropy+=prob*calcShannonEnt(subDataSet)
      infoGain=baseEntropy-newEntropy
    if infoGain>bestInfoGain:
      bestInfoGain=infoGain
      bestFeature=i
  #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
  #即是否小于等于bestSplitValue
  if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':
    bestSplitValue=bestSplitDict[labels[bestFeature]]
    labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
    for i in range(shape(dataSet)[0]):
      if dataSet[i][bestFeature]<=bestSplitValue:
        dataSet[i][bestFeature]=1
      else:
        dataSet[i][bestFeature]=0
  return bestFeature

#特征若已经划分完,节点下的样本还没有统一取值,则需要进行投票
def majorityCnt(classList):
  classCount={}
  for vote in classList:
    if vote not in classCount.keys():
      classCount[vote]=0
    classCount[vote]+=1
  return max(classCount)

#主程序,递归产生决策树
def createTree(dataSet,labels,data_full,labels_full):
  classList=[example[-1] for example in dataSet]
  if classList.count(classList[0])==len(classList):
    return classList[0]
  if len(dataSet[0])==1:
    return majorityCnt(classList)
  bestFeat=chooseBestFeatureToSplit(dataSet,labels)
  bestFeatLabel=labels[bestFeat]
  myTree={bestFeatLabel:{}}
  featValues=[example[bestFeat] for example in dataSet]
  uniqueVals=set(featValues)
  if type(dataSet[0][bestFeat]).__name__=='str':
    currentlabel=labels_full.index(labels[bestFeat])
    featValuesFull=[example[currentlabel] for example in data_full]
    uniqueValsFull=set(featValuesFull)
  del(labels[bestFeat])
  #针对bestFeat的每个取值,划分出一个子树。
  for value in uniqueVals:
    subLabels=labels[:]
    if type(dataSet[0][bestFeat]).__name__=='str':
      uniqueValsFull.remove(value)
    myTree[bestFeatLabel][value]=createTree(splitDataSet\
     (dataSet,bestFeat,value),subLabels,data_full,labels_full)
  if type(dataSet[0][bestFeat]).__name__=='str':
    for value in uniqueValsFull:
      myTree[bestFeatLabel][value]=majorityCnt(classList)
  return myTree

import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")


#计算树的叶子节点数量
def getNumLeafs(myTree):
  numLeafs=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      numLeafs+=getNumLeafs(secondDict[key])
    else: numLeafs+=1
  return numLeafs

#计算树的最大深度
def getTreeDepth(myTree):
  maxDepth=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      thisDepth=1+getTreeDepth(secondDict[key])
    else: thisDepth=1
    if thisDepth>maxDepth:
      maxDepth=thisDepth
  return maxDepth

#画节点
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
  createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
  xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\
  bbox=nodeType,arrowprops=arrow_args)

#画箭头上的文字
def plotMidText(cntrPt,parentPt,txtString):
  lens=len(txtString)
  xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
  yMid=(parentPt[1]+cntrPt[1])/2.0
  createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
  numLeafs=getNumLeafs(myTree)
  depth=getTreeDepth(myTree)
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
  plotMidText(cntrPt,parentPt,nodeTxt)
  plotNode(firstStr,cntrPt,parentPt,decisionNode)
  secondDict=myTree[firstStr]
  plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      plotTree(secondDict[key],cntrPt,str(key))
    else:
      plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
      plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
      plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
  plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD

def createPlot(inTree):
  fig=plt.figure(1,facecolor='white')
  fig.clf()
  axprops=dict(xticks=[],yticks=[])
  createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
  plotTree.totalW=float(getNumLeafs(inTree))
  plotTree.totalD=float(getTreeDepth(inTree))
  plotTree.x0ff=-0.5/plotTree.totalW
  plotTree.y0ff=1.0
  plotTree(inTree,(0.5,1.0),'')
  plt.show()

df=pd.read_csv('watermelon_4_3.csv')
data=df.values[:,1:].tolist()
data_full=data[:]
labels=df.columns.values[1:-1].tolist()
labels_full=labels[:]
myTree=createTree(data,labels,data_full,labels_full)
print(myTree)
createPlot(myTree)

最终结果如下:

{'texture': {'blur': 0, 'little_blur': {'touch': {'soft_stick': 1, 'hard_smooth': 0}}, 'distinct': {'density<=0.38149999999999995': {0: 1, 1: 0}}}}

得到的决策树如下:

基于ID3决策树算法的实现(Python版)

参考资料:

《机器学习实战》

《机器学习》周志华著

以上这篇基于ID3决策树算法的实现(Python版)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python益智游戏计算汉诺塔问题示例
Mar 05 Python
Python中实现两个字典(dict)合并的方法
Sep 23 Python
Python中AND、OR的一个使用小技巧
Feb 18 Python
TensorFlow变量管理详解
Mar 10 Python
基于python实现学生管理系统
Oct 17 Python
在Django admin中编辑ManyToManyField的实现方法
Aug 09 Python
django框架cookie和session用法实例详解
Dec 10 Python
Python爬虫库requests获取响应内容、响应状态码、响应头
Jan 25 Python
python 非线性规划方式(scipy.optimize.minimize)
Feb 11 Python
python实现将两个文件夹合并至另一个文件夹(制作数据集)
Apr 03 Python
python cookie反爬处理的实现
Nov 01 Python
python爬虫判断招聘信息是否存在的实例代码
Nov 20 Python
Python基础知识_浅谈用户交互
May 31 #Python
python数据类型_字符串常用操作(详解)
May 30 #Python
python数据类型_元组、字典常用操作方法(介绍)
May 30 #Python
node.js获取参数的常用方法(总结)
May 29 #Python
老生常谈python函数参数的区别(必看篇)
May 29 #Python
Python进阶_关于命名空间与作用域(详解)
May 29 #Python
浅谈对yield的初步理解
May 29 #Python
You might like
PHP读取XML值的代码(推荐)
2011/01/01 PHP
PHP7多线程搭建教程
2017/04/21 PHP
图片之间的切换
2006/06/26 Javascript
OfflineSave离线保存代码再次发布使用说明
2007/05/23 Javascript
event.keyCode键码值表 附只能输入特定的字符串代码
2009/05/15 Javascript
Javascript闭包与函数柯里化浅析
2016/06/22 Javascript
浅谈js中的三种继承方式及其优缺点
2016/08/10 Javascript
AngularJS中ng-class用法实例分析
2017/07/06 Javascript
vue v-model动态生成详解
2018/06/30 Javascript
Angular中的ng-template及angular 使用ngTemplateOutlet 指令的方法
2018/08/08 Javascript
Node.js中的不安全跳转如何防御详解
2018/10/21 Javascript
nodejs中方法和模块用法示例
2018/12/24 NodeJs
Webpack中SplitChunksPlugin 配置参数详解
2020/03/24 Javascript
JS实现数据动态渲染的竖向步骤条
2020/06/24 Javascript
Vant picker 多级联动操作
2020/11/02 Javascript
[42:27]DOTA2上海特级锦标赛主赛事日 - 3 败者组第三轮#2Fnatic VS OG第三局
2016/03/05 DOTA
python网络编程学习笔记(七):HTML和XHTML解析(HTMLParser、BeautifulSoup)
2014/06/09 Python
Windows和Linux下Python输出彩色文字的方法教程
2017/05/02 Python
Python使用win32 COM实现Excel的写入与保存功能示例
2018/05/03 Python
Python实现带下标索引的遍历操作示例
2019/05/30 Python
使用Python画股票的K线图的方法步骤
2019/06/28 Python
Python MySQL 日期时间格式化作为参数的操作
2020/03/02 Python
python 数据库查询返回list或tuple实例
2020/05/15 Python
python框架flask入门之环境搭建及开启调试
2020/06/07 Python
Python 使用SFTP和FTP实现对服务器的文件下载功能
2020/12/17 Python
中国第一家杂志折扣订阅网:杂志铺
2016/08/30 全球购物
蒂芙尼澳大利亚官方网站:Tiffany&Co. Australia
2017/08/27 全球购物
Desigual德国官网:在线购买原创服装
2018/03/27 全球购物
C#面试题问题集
2016/04/02 面试题
动物学专业毕业生求职信
2013/10/11 职场文书
主治医师岗位职责
2013/12/10 职场文书
办公室文书岗位职责
2013/12/16 职场文书
简历上的自我评价
2014/02/03 职场文书
实习生岗位职责
2014/04/12 职场文书
创业计划书之婴幼儿游泳馆
2019/09/11 职场文书
如何利用 CSS Overview 面板重构优化你的网站
2021/10/24 HTML / CSS