机器学习python实战之决策树


Posted in Python onNovember 01, 2017

决策树原理:从数据集中找出决定性的特征对数据集进行迭代划分,直到某个分支下的数据都属于同一类型,或者已经遍历了所有划分数据集的特征,停止决策树算法。

每次划分数据集的特征都有很多,那么我们怎么来选择到底根据哪一个特征划分数据集呢?这里我们需要引入信息增益和信息熵的概念。

一、信息增益

划分数据集的原则是:将无序的数据变的有序。在划分数据集之前之后信息发生的变化称为信息增益。知道如何计算信息增益,我们就可以计算根据每个特征划分数据集获得的信息增益,选择信息增益最高的特征就是最好的选择。首先我们先来明确一下信息的定义:符号xi的信息定义为 l(xi)=-log2 p(xi),p(xi)为选择该类的概率。那么信息源的熵H=-∑p(xi)·log2 p(xi)。根据这个公式我们下面编写代码计算香农熵

def calcShannonEnt(dataSet):
 NumEntries = len(dataSet)
 labelsCount = {}
 for i in dataSet:
  currentlabel = i[-1]
  if currentlabel not in labelsCount.keys():
   labelsCount[currentlabel]=0
  labelsCount[currentlabel]+=1
 ShannonEnt = 0.0
 for key in labelsCount:
  prob = labelsCount[key]/NumEntries
  ShannonEnt -= prob*log(prob,2)
 return ShannonEnt

上面的自定义函数我们需要在之前导入log方法,from math import log。 我们可以先用一个简单的例子来测试一下

def createdataSet():
 #dataSet = [['1','1','yes'],['1','0','no'],['0','1','no'],['0','0','no']]
 dataSet = [[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,0,'no']]
 labels = ['no surfacing','flippers']
 return dataSet,labels

机器学习python实战之决策树

这里的熵为0.811,当我们增加数据的类别时,熵会增加。这里更改后的数据集的类别有三种‘yes'、‘no'、‘maybe',也就是说数据越混乱,熵就越大。

机器学习python实战之决策树

分类算法出了需要计算信息熵,还需要划分数据集。决策树算法中我们对根据每个特征划分的数据集计算一次熵,然后判断按照哪个特征划分是最好的划分方式。

def splitDataSet(dataSet,axis,value):
 retDataSet = []
 for featVec in dataSet:
  if featVec[axis] == value:
   reducedfeatVec = featVec[:axis]
   reducedfeatVec.extend(featVec[axis+1:])
   retDataSet.append(reducedfeatVec)
 return retDataSet

axis表示划分数据集的特征,value表示特征的返回值。这里需要注意extend方法和append方法的区别。举例来说明这个区别

机器学习python实战之决策树

下面我们测试一下划分数据集函数的结果:

机器学习python实战之决策树

axis=0,value=1,按myDat数据集的第0个特征向量是否等于1进行划分。

接下来我们将遍历整个数据集,对每个划分的数据集计算香农熵,找到最好的特征划分方式

def choosebestfeatureToSplit(dataSet):
 Numfeatures = len(dataSet)-1
 BaseShannonEnt = calcShannonEnt(dataSet)
 bestInfoGain=0.0
 bestfeature = -1
 for i in range(Numfeatures):
  featlist = [example[i] for example in dataSet]
  featSet = set(featlist)
  newEntropy = 0.0
  for value in featSet:
   subDataSet = splitDataSet(dataSet,i,value)
   prob = len(subDataSet)/len(dataSet)
   newEntropy += prob*calcShannonEnt(subDataSet) 
  infoGain = BaseShannonEnt-newEntropy
  if infoGain>bestInfoGain:
   bestInfoGain=infoGain
   bestfeature = i
 return bestfeature

信息增益是熵的减少或数据无序度的减少。最后比较所有特征中的信息增益,返回最好特征划分的索引。函数测试结果为

机器学习python实战之决策树

接下来开始递归构建决策树,我们需要在构建前计算列的数目,查看算法是否使用了所有的属性。这个函数跟跟第二章的calssify0采用同样的方法

def majorityCnt(classlist):
 ClassCount = {}
 for vote in classlist:
  if vote not in ClassCount.keys():
   ClassCount[vote]=0
  ClassCount[vote]+=1
 sortedClassCount = sorted(ClassCount.items(),key = operator.itemgetter(1),reverse = True)
 return sortedClassCount[0][0]

def createTrees(dataSet,labels):
 classList = [example[-1] for example in dataSet]
 if classList.count(classList[0]) == len(classList):
  return classList[0]
 if len(dataSet[0])==1:
  return majorityCnt(classList)
 bestfeature = choosebestfeatureToSplit(dataSet)
 bestfeatureLabel = labels[bestfeature]
 myTree = {bestfeatureLabel:{}}
 del(labels[bestfeature])
 featValue = [example[bestfeature] for example in dataSet]
 uniqueValue = set(featValue)
 for value in uniqueValue:
  subLabels = labels[:]
  myTree[bestfeatureLabel][value] = createTrees(splitDataSet(dataSet,bestfeature,value),subLabels)
 return myTree

最终决策树得到的结果如下:

机器学习python实战之决策树

有了如上的结果,我们看起来并不直观,所以我们接下来用matplotlib注解绘制树形图。matplotlib提供了一个注解工具annotations,它可以在数据图形上添加文本注释。我们先来测试一下这个注解工具的使用。

import matplotlib.pyplot as plt
decisionNode = dict(boxstyle = 'sawtooth',fc = '0.8')
leafNode = dict(boxstyle = 'sawtooth',fc = '0.8')
arrow_args = dict(arrowstyle = '<-')

def plotNode(nodeTxt,centerPt,parentPt,nodeType):
 createPlot.ax1.annotate(nodeTxt,xy = parentPt,xycoords = 'axes fraction',\
       xytext = centerPt,textcoords = 'axes fraction',\
       va = 'center',ha = 'center',bbox = nodeType,\
       arrowprops = arrow_args)
 
def createPlot():
 fig = plt.figure(1,facecolor = 'white')
 fig.clf()
 createPlot.ax1 = plt.subplot(111,frameon = False)
 plotNode('test1',(0.5,0.1),(0.1,0.5),decisionNode)
 plotNode('test2',(0.8,0.1),(0.3,0.8),leafNode)
 plt.show()

机器学习python实战之决策树

测试过这个小例子之后我们就要开始构建注解树了。虽然有xy坐标,但在如何放置树节点的时候我们会遇到一些麻烦。所以我们需要知道有多少个叶节点,树的深度有多少层。下面的两个函数就是为了得到叶节点数目和树的深度,两个函数有相同的结构,从第一个关键字开始遍历所有的子节点,使用type()函数判断子节点是否为字典类型,若为字典类型,则可以认为该子节点是一个判断节点,然后递归调用函数getNumleafs(),使得函数遍历整棵树,并返回叶子节点数。第2个函数getTreeDepth()计算遍历过程中遇到判断节点的个数。该函数的终止条件是叶子节点,一旦到达叶子节点,则从递归调用中返回,并将计算树深度的变量加一

def getNumleafs(myTree):
 numLeafs=0
 key_sorted= sorted(myTree.keys())
 firstStr = key_sorted[0]
 secondDict = myTree[firstStr]
 for key in secondDict.keys():
  if type(secondDict[key]).__name__=='dict':
   numLeafs+=getNumleafs(secondDict[key])
  else:
   numLeafs+=1
 return numLeafs

def getTreeDepth(myTree):
 maxdepth=0
 key_sorted= sorted(myTree.keys())
 firstStr = key_sorted[0]
 secondDict = myTree[firstStr]
 for key in secondDict.keys():
  if type(secondDict[key]).__name__ == 'dict':
   thedepth=1+getTreeDepth(secondDict[key])
  else:
   thedepth=1
  if thedepth>maxdepth:
   maxdepth=thedepth
 return maxdepth

测试结果如下

机器学习python实战之决策树

我们先给出最终的决策树图来验证上述结果的正确性

机器学习python实战之决策树

可以看出树的深度确实是有两层,叶节点的数目是3。接下来我们给出绘制决策树图的关键函数,结果就得到上图中决策树。

def plotMidText(cntrPt,parentPt,txtString):
 xMid = (parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
 yMid = (parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
 createPlot.ax1.text(xMid,yMid,txtString)
 
def plotTree(myTree,parentPt,nodeTxt):
 numLeafs = getNumleafs(myTree)
 depth = getTreeDepth(myTree)
 key_sorted= sorted(myTree.keys())
 firstStr = key_sorted[0]
 cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
 plotMidText(cntrPt,parentPt,nodeTxt)
 plotNode(firstStr,cntrPt,parentPt,decisionNode)
 secondDict = myTree[firstStr]
 plotTree.yOff -= 1.0/plotTree.totalD
 for key in secondDict.keys():
  if type(secondDict[key]).__name__ == 'dict':
   plotTree(secondDict[key],cntrPt,str(key))
  else:
   plotTree.xOff+=1.0/plotTree.totalW
   plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
   plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
 plotTree.yOff+=1.0/plotTree.totalD
 
def createPlot(inTree):
 fig = plt.figure(1,facecolor = 'white')
 fig.clf()
 axprops = dict(xticks = [],yticks = [])
 createPlot.ax1 = plt.subplot(111,frameon = False,**axprops)
 plotTree.totalW = float(getNumleafs(inTree))
 plotTree.totalD = float(getTreeDepth(inTree))
 plotTree.xOff = -0.5/ plotTree.totalW; plotTree.yOff = 1.0
 plotTree(inTree,(0.5,1.0),'')
 plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python网页解析利器BeautifulSoup安装使用介绍
Mar 17 Python
以一个投票程序的实例来讲解Python的Django框架使用
Feb 18 Python
python Celery定时任务的示例
Mar 13 Python
使用Python读取二进制文件的实例讲解
Jul 09 Python
Python sorted函数详解(高级篇)
Sep 18 Python
浅谈python3发送post请求参数为空的情况
Dec 28 Python
Django JWT Token RestfulAPI用户认证详解
Jan 23 Python
使用TensorFlow直接获取处理MNIST数据方式
Feb 10 Python
python GUI库图形界面开发之PyQt5开发环境配置与基础使用
Feb 25 Python
python生成并处理uuid的实现方式
Mar 03 Python
Python异常类型以及处理方法汇总
Jun 05 Python
浅析Python实现DFA算法
Jun 26 Python
详解Python开发中如何使用Hook技巧
Nov 01 #Python
python利用标准库如何获取本地IP示例详解
Nov 01 #Python
你眼中的Python大牛 应该都有这份书单
Oct 31 #Python
Python生成数字图片代码分享
Oct 31 #Python
python使用标准库根据进程名如何获取进程的pid详解
Oct 31 #Python
Python列表删除的三种方法代码分享
Oct 31 #Python
Python文件的读写和异常代码示例
Oct 31 #Python
You might like
PHP文件读写操作之文件读取方法详解
2011/01/13 PHP
PHP面向对象学习笔记之一 基础概念
2012/10/06 PHP
php读取目录所有文件信息dir示例
2014/03/18 PHP
php pdo oracle中文乱码的快速解决方法
2016/05/16 PHP
PHP实现bitmap位图排序与求交集的方法
2016/07/28 PHP
phpMyAdmin无法登陆的解决方法
2017/04/27 PHP
php精度计算的问题解析
2019/06/21 PHP
Laravel实现ORM带条件搜索分页
2019/10/24 PHP
下拉列表选择项的选中在不同浏览器中的兼容性问题探讨
2013/09/18 Javascript
jquery延迟对象解析
2016/10/26 Javascript
JS文件/图片从电脑里面拖拽到浏览器上传文件/图片
2017/03/08 Javascript
vue学习笔记之vue1.0和vue2.0的区别介绍
2017/05/17 Javascript
Vue-resource拦截器判断token失效跳转的实例
2017/10/27 Javascript
微信小程序用户位置权限的获取方法(拒绝后提醒)
2018/11/15 Javascript
js实现搜索栏效果
2018/11/16 Javascript
Vue3.0结合bootstrap创建多页面应用
2019/05/28 Javascript
js字符串类型String常用操作实例总结
2019/07/05 Javascript
vue 实现单选框设置默认选中值
2019/11/07 Javascript
jQuery实现小火箭返回顶部特效
2020/02/03 jQuery
深入浅析golang zap 日志库使用(含文件切割、分级别存储和全局使用等)
2020/02/19 Javascript
浅谈Vue static 静态资源路径 和 style问题
2020/11/07 Javascript
解析Python中的eval()、exec()及其相关函数
2017/12/20 Python
numpy.ndarray 交换多维数组(矩阵)的行/列方法
2018/08/02 Python
python中强大的format函数实例详解
2018/12/05 Python
Python 编程速成(推荐)
2019/04/15 Python
python flask 如何修改默认端口号的方法步骤
2019/07/12 Python
基于Python的一个自动录入表格的小程序
2020/08/05 Python
css3弹性盒模型(Flexbox)详细介绍
2014/10/08 HTML / CSS
美国男士和女士奢侈品折扣手表购物网站:Certified Watch Store
2018/06/13 全球购物
印尼极简主义和实惠的在线家具店:Fabelio
2019/03/27 全球购物
Lacoste(法国鳄鱼)加拿大官网:以标志性的POLO衫而闻名
2019/05/15 全球购物
莫斯科购买书籍网站:Book24
2020/01/12 全球购物
琳达·法罗眼镜英国官网:Linda Farrow英国
2021/01/19 全球购物
Vinatis德国:法国领先的葡萄酒邮购公司
2020/09/07 全球购物
Why we need EJB
2016/10/20 面试题
2014国庆节演讲稿:祖国在我心中(400字)
2014/09/25 职场文书