机器学习python实战之决策树


Posted in Python onNovember 01, 2017

决策树原理:从数据集中找出决定性的特征对数据集进行迭代划分,直到某个分支下的数据都属于同一类型,或者已经遍历了所有划分数据集的特征,停止决策树算法。

每次划分数据集的特征都有很多,那么我们怎么来选择到底根据哪一个特征划分数据集呢?这里我们需要引入信息增益和信息熵的概念。

一、信息增益

划分数据集的原则是:将无序的数据变的有序。在划分数据集之前之后信息发生的变化称为信息增益。知道如何计算信息增益,我们就可以计算根据每个特征划分数据集获得的信息增益,选择信息增益最高的特征就是最好的选择。首先我们先来明确一下信息的定义:符号xi的信息定义为 l(xi)=-log2 p(xi),p(xi)为选择该类的概率。那么信息源的熵H=-∑p(xi)·log2 p(xi)。根据这个公式我们下面编写代码计算香农熵

def calcShannonEnt(dataSet):
 NumEntries = len(dataSet)
 labelsCount = {}
 for i in dataSet:
  currentlabel = i[-1]
  if currentlabel not in labelsCount.keys():
   labelsCount[currentlabel]=0
  labelsCount[currentlabel]+=1
 ShannonEnt = 0.0
 for key in labelsCount:
  prob = labelsCount[key]/NumEntries
  ShannonEnt -= prob*log(prob,2)
 return ShannonEnt

上面的自定义函数我们需要在之前导入log方法,from math import log。 我们可以先用一个简单的例子来测试一下

def createdataSet():
 #dataSet = [['1','1','yes'],['1','0','no'],['0','1','no'],['0','0','no']]
 dataSet = [[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,0,'no']]
 labels = ['no surfacing','flippers']
 return dataSet,labels

机器学习python实战之决策树

这里的熵为0.811,当我们增加数据的类别时,熵会增加。这里更改后的数据集的类别有三种‘yes'、‘no'、‘maybe',也就是说数据越混乱,熵就越大。

机器学习python实战之决策树

分类算法出了需要计算信息熵,还需要划分数据集。决策树算法中我们对根据每个特征划分的数据集计算一次熵,然后判断按照哪个特征划分是最好的划分方式。

def splitDataSet(dataSet,axis,value):
 retDataSet = []
 for featVec in dataSet:
  if featVec[axis] == value:
   reducedfeatVec = featVec[:axis]
   reducedfeatVec.extend(featVec[axis+1:])
   retDataSet.append(reducedfeatVec)
 return retDataSet

axis表示划分数据集的特征,value表示特征的返回值。这里需要注意extend方法和append方法的区别。举例来说明这个区别

机器学习python实战之决策树

下面我们测试一下划分数据集函数的结果:

机器学习python实战之决策树

axis=0,value=1,按myDat数据集的第0个特征向量是否等于1进行划分。

接下来我们将遍历整个数据集,对每个划分的数据集计算香农熵,找到最好的特征划分方式

def choosebestfeatureToSplit(dataSet):
 Numfeatures = len(dataSet)-1
 BaseShannonEnt = calcShannonEnt(dataSet)
 bestInfoGain=0.0
 bestfeature = -1
 for i in range(Numfeatures):
  featlist = [example[i] for example in dataSet]
  featSet = set(featlist)
  newEntropy = 0.0
  for value in featSet:
   subDataSet = splitDataSet(dataSet,i,value)
   prob = len(subDataSet)/len(dataSet)
   newEntropy += prob*calcShannonEnt(subDataSet) 
  infoGain = BaseShannonEnt-newEntropy
  if infoGain>bestInfoGain:
   bestInfoGain=infoGain
   bestfeature = i
 return bestfeature

信息增益是熵的减少或数据无序度的减少。最后比较所有特征中的信息增益,返回最好特征划分的索引。函数测试结果为

机器学习python实战之决策树

接下来开始递归构建决策树,我们需要在构建前计算列的数目,查看算法是否使用了所有的属性。这个函数跟跟第二章的calssify0采用同样的方法

def majorityCnt(classlist):
 ClassCount = {}
 for vote in classlist:
  if vote not in ClassCount.keys():
   ClassCount[vote]=0
  ClassCount[vote]+=1
 sortedClassCount = sorted(ClassCount.items(),key = operator.itemgetter(1),reverse = True)
 return sortedClassCount[0][0]

def createTrees(dataSet,labels):
 classList = [example[-1] for example in dataSet]
 if classList.count(classList[0]) == len(classList):
  return classList[0]
 if len(dataSet[0])==1:
  return majorityCnt(classList)
 bestfeature = choosebestfeatureToSplit(dataSet)
 bestfeatureLabel = labels[bestfeature]
 myTree = {bestfeatureLabel:{}}
 del(labels[bestfeature])
 featValue = [example[bestfeature] for example in dataSet]
 uniqueValue = set(featValue)
 for value in uniqueValue:
  subLabels = labels[:]
  myTree[bestfeatureLabel][value] = createTrees(splitDataSet(dataSet,bestfeature,value),subLabels)
 return myTree

最终决策树得到的结果如下:

机器学习python实战之决策树

有了如上的结果,我们看起来并不直观,所以我们接下来用matplotlib注解绘制树形图。matplotlib提供了一个注解工具annotations,它可以在数据图形上添加文本注释。我们先来测试一下这个注解工具的使用。

import matplotlib.pyplot as plt
decisionNode = dict(boxstyle = 'sawtooth',fc = '0.8')
leafNode = dict(boxstyle = 'sawtooth',fc = '0.8')
arrow_args = dict(arrowstyle = '<-')

def plotNode(nodeTxt,centerPt,parentPt,nodeType):
 createPlot.ax1.annotate(nodeTxt,xy = parentPt,xycoords = 'axes fraction',\
       xytext = centerPt,textcoords = 'axes fraction',\
       va = 'center',ha = 'center',bbox = nodeType,\
       arrowprops = arrow_args)
 
def createPlot():
 fig = plt.figure(1,facecolor = 'white')
 fig.clf()
 createPlot.ax1 = plt.subplot(111,frameon = False)
 plotNode('test1',(0.5,0.1),(0.1,0.5),decisionNode)
 plotNode('test2',(0.8,0.1),(0.3,0.8),leafNode)
 plt.show()

机器学习python实战之决策树

测试过这个小例子之后我们就要开始构建注解树了。虽然有xy坐标,但在如何放置树节点的时候我们会遇到一些麻烦。所以我们需要知道有多少个叶节点,树的深度有多少层。下面的两个函数就是为了得到叶节点数目和树的深度,两个函数有相同的结构,从第一个关键字开始遍历所有的子节点,使用type()函数判断子节点是否为字典类型,若为字典类型,则可以认为该子节点是一个判断节点,然后递归调用函数getNumleafs(),使得函数遍历整棵树,并返回叶子节点数。第2个函数getTreeDepth()计算遍历过程中遇到判断节点的个数。该函数的终止条件是叶子节点,一旦到达叶子节点,则从递归调用中返回,并将计算树深度的变量加一

def getNumleafs(myTree):
 numLeafs=0
 key_sorted= sorted(myTree.keys())
 firstStr = key_sorted[0]
 secondDict = myTree[firstStr]
 for key in secondDict.keys():
  if type(secondDict[key]).__name__=='dict':
   numLeafs+=getNumleafs(secondDict[key])
  else:
   numLeafs+=1
 return numLeafs

def getTreeDepth(myTree):
 maxdepth=0
 key_sorted= sorted(myTree.keys())
 firstStr = key_sorted[0]
 secondDict = myTree[firstStr]
 for key in secondDict.keys():
  if type(secondDict[key]).__name__ == 'dict':
   thedepth=1+getTreeDepth(secondDict[key])
  else:
   thedepth=1
  if thedepth>maxdepth:
   maxdepth=thedepth
 return maxdepth

测试结果如下

机器学习python实战之决策树

我们先给出最终的决策树图来验证上述结果的正确性

机器学习python实战之决策树

可以看出树的深度确实是有两层,叶节点的数目是3。接下来我们给出绘制决策树图的关键函数,结果就得到上图中决策树。

def plotMidText(cntrPt,parentPt,txtString):
 xMid = (parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
 yMid = (parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
 createPlot.ax1.text(xMid,yMid,txtString)
 
def plotTree(myTree,parentPt,nodeTxt):
 numLeafs = getNumleafs(myTree)
 depth = getTreeDepth(myTree)
 key_sorted= sorted(myTree.keys())
 firstStr = key_sorted[0]
 cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
 plotMidText(cntrPt,parentPt,nodeTxt)
 plotNode(firstStr,cntrPt,parentPt,decisionNode)
 secondDict = myTree[firstStr]
 plotTree.yOff -= 1.0/plotTree.totalD
 for key in secondDict.keys():
  if type(secondDict[key]).__name__ == 'dict':
   plotTree(secondDict[key],cntrPt,str(key))
  else:
   plotTree.xOff+=1.0/plotTree.totalW
   plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
   plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
 plotTree.yOff+=1.0/plotTree.totalD
 
def createPlot(inTree):
 fig = plt.figure(1,facecolor = 'white')
 fig.clf()
 axprops = dict(xticks = [],yticks = [])
 createPlot.ax1 = plt.subplot(111,frameon = False,**axprops)
 plotTree.totalW = float(getNumleafs(inTree))
 plotTree.totalD = float(getTreeDepth(inTree))
 plotTree.xOff = -0.5/ plotTree.totalW; plotTree.yOff = 1.0
 plotTree(inTree,(0.5,1.0),'')
 plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中除法使用的注意事项
Aug 21 Python
Python文件操作基本流程代码实例
Dec 11 Python
python和shell监控linux服务器的详细代码
Jun 22 Python
python抓取京东小米8手机配置信息
Nov 13 Python
Python中is和==的区别详解
Nov 15 Python
python实现Excel文件转换为TXT文件
Apr 28 Python
Python MongoDB 插入数据时已存在则不执行,不存在则插入的解决方法
Sep 24 Python
python 实现将Numpy数组保存为图像
Jan 09 Python
IntelliJ 中配置 Anaconda的过程图解
Jun 01 Python
Python爬虫谷歌Chrome F12抓包过程原理解析
Jun 04 Python
python删除csv文件的行列
Apr 06 Python
python blinker 信号库
May 04 Python
详解Python开发中如何使用Hook技巧
Nov 01 #Python
python利用标准库如何获取本地IP示例详解
Nov 01 #Python
你眼中的Python大牛 应该都有这份书单
Oct 31 #Python
Python生成数字图片代码分享
Oct 31 #Python
python使用标准库根据进程名如何获取进程的pid详解
Oct 31 #Python
Python列表删除的三种方法代码分享
Oct 31 #Python
Python文件的读写和异常代码示例
Oct 31 #Python
You might like
php学习之 循环结构实现代码
2011/06/09 PHP
PHP删除HTMl标签的三种解决方法
2013/06/30 PHP
使用php实现从身份证中提取生日
2016/05/09 PHP
Zend Framework常用校验器详解
2016/12/09 PHP
详解no input file specified 三种解决方法
2019/11/29 PHP
限制复选框的最大可选数
2006/07/01 Javascript
js每次Title显示不同的名言
2008/09/25 Javascript
jquery $(document).ready() 与window.onload的区别
2009/12/28 Javascript
ExtJs GridPanel简单的增删改实现代码
2010/08/26 Javascript
JavaScript高级程序设计(第3版)学习笔记2 js基础语法
2012/10/11 Javascript
推荐一款jQuery插件模板
2015/01/09 Javascript
php常见的页面跳转方法汇总
2015/04/15 Javascript
基于JQuery实现图片上传预览与删除操作
2016/05/24 Javascript
微信小程序之电影影评小程序制作代码
2017/08/03 Javascript
一篇文章让你彻底弄懂JS的事件冒泡和事件捕获
2017/08/14 Javascript
vue项目优化之通过keep-alive数据缓存的方法
2017/12/11 Javascript
使用JQuery自动完成插件Auto Complete详解
2019/06/18 jQuery
简单了解Vue + ElementUI后台管理模板
2020/04/07 Javascript
webpack4从0搭建组件库的实现
2020/11/29 Javascript
[02:19]2014DOTA2国际邀请赛 专访820少年们一起去追梦吧
2014/07/14 DOTA
Python操作MySQL数据库9个实用实例
2015/12/11 Python
详解Python爬虫的基本写法
2016/01/08 Python
python实现一组典型数据格式转换
2018/12/15 Python
对Python闭包与延迟绑定的方法详解
2019/01/07 Python
Python+PyQt5实现美剧爬虫可视工具的方法
2019/04/25 Python
创建Django项目图文实例详解
2019/06/06 Python
python hashlib加密实现代码
2019/10/17 Python
浅析Python requests 模块
2020/10/09 Python
优秀共产党员先进事迹
2014/01/27 职场文书
庆元旦迎新年广播稿
2014/02/18 职场文书
人事专员岗位职责说明书
2014/07/30 职场文书
物业项目经理岗位职责
2015/04/01 职场文书
2015年平安创建工作总结
2015/04/29 职场文书
勇敢的心观后感
2015/06/09 职场文书
欠条样本
2015/07/03 职场文书
app场景下uniapp的扫码记录
2022/07/23 Java/Android