决策树剪枝算法的python实现方法详解


Posted in Python onSeptember 18, 2019

本文实例讲述了决策树剪枝算法的python实现方法。分享给大家供大家参考,具体如下:

决策树是一种依托决策而建立起来的一种树。在机器学习中,决策树是一种预测模型,代表的是一种对象属性与对象值之间的一种映射关系,每一个节点代表某个对象,树中的每一个分叉路径代表某个可能的属性值,而每一个叶子节点则对应从根节点到该叶子节点所经历的路径所表示的对象的值。决策树仅有单一输出,如果有多个输出,可以分别建立独立的决策树以处理不同的输出。

ID3算法:ID3算法是决策树的一种,是基于奥卡姆剃刀原理的,即用尽量用较少的东西做更多的事。ID3算法,即Iterative Dichotomiser 3,迭代二叉树3代,是Ross Quinlan发明的一种决策树算法,这个算法的基础就是上面提到的奥卡姆剃刀原理,越是小型的决策树越优于大的决策树,尽管如此,也不总是生成最小的树型结构,而是一个启发式算法。在信息论中,期望信息越小,那么信息增益就越大,从而纯度就越高。ID3算法的核心思想就是以信息增益来度量属性的选择,选择分裂后信息增益最大的属性进行分裂。该算法采用自顶向下的贪婪搜索遍历可能的决策空间。
信息熵,将其定义为离散随机事件出现的概率,一个系统越是有序,信息熵就越低,反之一个系统越是混乱,它的信息熵就越高。所以信息熵可以被认为是系统有序化程度的一个度量。

基尼指数:在CART里面划分决策树的条件是采用Gini Index,定义如下:gini(T)=1−sumnj=1p2j。其中,( p_j )是类j在T中的相对频率,当类在T中是倾斜的时,gini(T)会最小。将T划分为T1(实例数为N1)和T2(实例数为N2)两个子集后,划分数据的Gini定义如下:ginisplit(T)=fracN1Ngini(T1)+fracN2Ngini(T2),然后选择其中最小的(gini_{split}(T) )作为结点划分决策树
具体实现
首先用函数calcShanno计算数据集的香农熵,给所有可能的分类创建字典

def calcShannonEnt(dataSet): 
  numEntries = len(dataSet) 
  labelCounts = {} 
  # 给所有可能分类创建字典 
  for featVec in dataSet: 
    currentLabel = featVec[-1] 
    if currentLabel not in labelCounts.keys(): 
      labelCounts[currentLabel] = 0
    labelCounts[currentLabel] += 1
  shannonEnt = 0.0
  # 以2为底数计算香农熵
  for key in labelCounts:
    prob = float(labelCounts[key]) / numEntries
    shannonEnt -= prob * log(prob, 2)
  return shannonEnt
# 对离散变量划分数据集,取出该特征取值为value的所有样本
def splitDataSet(dataSet, axis, value):
  retDataSet = []
  for featVec in dataSet:
    if featVec[axis] == value:
      reducedFeatVec = featVec[:axis]
      reducedFeatVec.extend(featVec[axis + 1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet

对连续变量划分数据集,direction规定划分的方向, 决定是划分出小于value的数据样本还是大于value的数据样本集

numFeatures = len(dataSet[0]) - 1
  baseEntropy = calcShannonEnt(dataSet)
  bestInfoGain = 0.0
  bestFeature = -1
  bestSplitDict = {}
  for i in range(numFeatures):
    featList = [example[i] for example in dataSet]
    # 对连续型特征进行处理
    if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int':
      # 产生n-1个候选划分点
      sortfeatList = sorted(featList)
      splitList = []
      for j in range(len(sortfeatList) - 1):
        splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)

      bestSplitEntropy = 10000
      slen = len(splitList)
      # 求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
      for j in range(slen):
        value = splitList[j]
        newEntropy = 0.0
        subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
        subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
        prob0 = len(subDataSet0) / float(len(dataSet))
        newEntropy += prob0 * calcShannonEnt(subDataSet0)
        prob1 = len(subDataSet1) / float(len(dataSet))
        newEntropy += prob1 * calcShannonEnt(subDataSet1)
        if newEntropy < bestSplitEntropy:
          bestSplitEntropy = newEntropy
          bestSplit = j
      # 用字典记录当前特征的最佳划分点
      bestSplitDict[labels[i]] = splitList[bestSplit]
      infoGain = baseEntropy - bestSplitEntropy
    # 对离散型特征进行处理
    else:
      uniqueVals = set(featList)
      newEntropy = 0.0
      # 计算该特征下每种划分的信息熵
      for value in uniqueVals:
        subDataSet = splitDataSet(dataSet, i, value)
        prob = len(subDataSet) / float(len(dataSet))
        newEntropy += prob * calcShannonEnt(subDataSet)
      infoGain = baseEntropy - newEntropy
    if infoGain > bestInfoGain:
      bestInfoGain = infoGain
      bestFeature = i
  # 若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
  # 即是否小于等于bestSplitValue
  if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int':
    bestSplitValue = bestSplitDict[labels[bestFeature]]
    labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
    for i in range(shape(dataSet)[0]):
      if dataSet[i][bestFeature] <= bestSplitValue:
        dataSet[i][bestFeature] = 1
      else:
        dataSet[i][bestFeature] = 0
  return bestFeature
def chooseBestFeatureToSplit(dataSet, labels):
  numFeatures = len(dataSet[0]) - 1
  baseEntropy = calcShannonEnt(dataSet)
  bestInfoGain = 0.0
  bestFeature = -1
  bestSplitDict = {}
  for i in range(numFeatures):
    featList = [example[i] for example in dataSet]
    # 对连续型特征进行处理
    if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int':
      # 产生n-1个候选划分点
      sortfeatList = sorted(featList)
      splitList = []
      for j in range(len(sortfeatList) - 1):
        splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)

      bestSplitEntropy = 10000
      slen = len(splitList)
      # 求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
      for j in range(slen):
        value = splitList[j]
        newEntropy = 0.0
        subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
        subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
        prob0 = len(subDataSet0) / float(len(dataSet))
        newEntropy += prob0 * calcShannonEnt(subDataSet0)
        prob1 = len(subDataSet1) / float(len(dataSet))
        newEntropy += prob1 * calcShannonEnt(subDataSet1)
        if newEntropy < bestSplitEntropy:
          bestSplitEntropy = newEntropy
          bestSplit = j
      # 用字典记录当前特征的最佳划分点
      bestSplitDict[labels[i]] = splitList[bestSplit]
      infoGain = baseEntropy - bestSplitEntropy
    # 对离散型特征进行处理
    else:
      uniqueVals = set(featList)
      newEntropy = 0.0
      # 计算该特征下每种划分的信息熵
      for value in uniqueVals:
        subDataSet = splitDataSet(dataSet, i, value)
        prob = len(subDataSet) / float(len(dataSet))
        newEntropy += prob * calcShannonEnt(subDataSet)
      infoGain = baseEntropy - newEntropy
    if infoGain > bestInfoGain:
      bestInfoGain = infoGain
      bestFeature = i
  # 若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
  # 即是否小于等于bestSplitValue
  if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int':
    bestSplitValue = bestSplitDict[labels[bestFeature]]
    labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
    for i in range(shape(dataSet)[0]):
      if dataSet[i][bestFeature] <= bestSplitValue:
        dataSet[i][bestFeature] = 1
      else:
        dataSet[i][bestFeature] = 0
  return bestFeature
``def classify(inputTree, featLabels, testVec):
  firstStr = inputTree.keys()[0]
  if u'<=' in firstStr:
    featvalue = float(firstStr.split(u"<=")[1])
    featkey = firstStr.split(u"<=")[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(featkey)
    if testVec[featIndex] <= featvalue:
      judge = 1
    else:
      judge = 0
    for key in secondDict.keys():
      if judge == int(key):
        if type(secondDict[key]).__name__ == 'dict':
          classLabel = classify(secondDict[key], featLabels, testVec)
        else:
          classLabel = secondDict[key]
  else:
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
      if testVec[featIndex] == key:
        if type(secondDict[key]).__name__ == 'dict':
          classLabel = classify(secondDict[key], featLabels, testVec)
        else:
          classLabel = secondDict[key]
  return classLabel
def majorityCnt(classList):
  classCount={}
  for vote in classList:
    if vote not in classCount.keys():
      classCount[vote]=0
    classCount[vote]+=1
  return max(classCount)
def testing_feat(feat, train_data, test_data, labels):
  class_list = [example[-1] for example in train_data]
  bestFeatIndex = labels.index(feat)
  train_data = [example[bestFeatIndex] for example in train_data]
  test_data = [(example[bestFeatIndex], example[-1]) for example in test_data]
  all_feat = set(train_data)
  error = 0.0
  for value in all_feat:
    class_feat = [class_list[i] for i in range(len(class_list)) if train_data[i] == value]
    major = majorityCnt(class_feat)
    for data in test_data:
      if data[0] == value and data[1] != major:
        error += 1.0
  # print 'myTree %d' % error
  return error

测试

error = 0.0
  for i in range(len(data_test)):
    if classify(myTree, labels, data_test[i]) != data_test[i][-1]:
      error += 1
  # print 'myTree %d' % error
  return float(error)
def testingMajor(major, data_test):
  error = 0.0
  for i in range(len(data_test)):
    if major != data_test[i][-1]:
      error += 1
  # print 'major %d' % error
  return float(error)

**递归产生决策树**

```def createTree(dataSet,labels,data_full,labels_full,test_data,mode):
  classList=[example[-1] for example in dataSet]
  if classList.count(classList[0])==len(classList):
    return classList[0]
  if len(dataSet[0])==1:
    return majorityCnt(classList)
  labels_copy = copy.deepcopy(labels)
  bestFeat=chooseBestFeatureToSplit(dataSet,labels)
  bestFeatLabel=labels[bestFeat]

  if mode == "unpro" or mode == "post":
    myTree = {bestFeatLabel: {}}
  elif mode == "prev":
    if testing_feat(bestFeatLabel, dataSet, test_data, labels_copy) < testingMajor(majorityCnt(classList),
                                            test_data):
      myTree = {bestFeatLabel: {}}
    else:
      return majorityCnt(classList)
  featValues=[example[bestFeat] for example in dataSet]
  uniqueVals=set(featValues)

  if type(dataSet[0][bestFeat]).__name__ == 'unicode':
    currentlabel = labels_full.index(labels[bestFeat])
    featValuesFull = [example[currentlabel] for example in data_full]
    uniqueValsFull = set(featValuesFull)

  del (labels[bestFeat])

  for value in uniqueVals:
    subLabels = labels[:]
    if type(dataSet[0][bestFeat]).__name__ == 'unicode':
      uniqueValsFull.remove(value)

    myTree[bestFeatLabel][value] = createTree(splitDataSet \
                           (dataSet, bestFeat, value), subLabels, data_full, labels_full,
                         splitDataSet \
                           (test_data, bestFeat, value), mode=mode)
  if type(dataSet[0][bestFeat]).__name__ == 'unicode':
    for value in uniqueValsFull:
      myTree[bestFeatLabel][value] = majorityCnt(classList)

  if mode == "post":
    if testing(myTree, test_data, labels_copy) > testingMajor(majorityCnt(classList), test_data):
      return majorityCnt(classList)
  return myTree








<div class="se-preview-section-delimiter"></div>

```**读入数据**

```def load_data(file_name):
  with open(r"dd.csv", 'rb') as f:
   df = pd.read_csv(f,sep=",")
   print(df)
   train_data = df.values[:11, 1:].tolist()
  print(train_data)
  test_data = df.values[11:, 1:].tolist()
  labels = df.columns.values[1:-1].tolist()
  return train_data, test_data, labels





<div class="se-preview-section-delimiter"></div>

```测试并绘制树图
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="round4", color='red') # 定义判断结点形态
leafNode = dict(boxstyle="circle", color='grey') # 定义叶结点形态
arrow_args = dict(arrowstyle="<-", color='blue') # 定义箭头


# 计算树的叶子节点数量
def getNumLeafs(myTree):
  numLeafs = 0
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  secondDict = myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      numLeafs += getNumLeafs(secondDict[key])
    else:
      numLeafs += 1
  return numLeafs


# 计算树的最大深度
def getTreeDepth(myTree):
  maxDepth = 0
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  secondDict = myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      thisDepth = 1 + getTreeDepth(secondDict[key])
    else:
      thisDepth = 1
    if thisDepth > maxDepth:
      maxDepth = thisDepth
  return maxDepth


# 画节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
  createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
              xytext=centerPt, textcoords='axes fraction', va="center", ha="center", \
              bbox=nodeType, arrowprops=arrow_args)


# 画箭头上的文字
def plotMidText(cntrPt, parentPt, txtString):
  lens = len(txtString)
  xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002
  yMid = (parentPt[1] + cntrPt[1]) / 2.0
  createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
  numLeafs = getNumLeafs(myTree)
  depth = getTreeDepth(myTree)
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
  plotMidText(cntrPt, parentPt, nodeTxt)
  plotNode(firstStr, cntrPt, parentPt, decisionNode)
  secondDict = myTree[firstStr]
  plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      plotTree(secondDict[key], cntrPt, str(key))
    else:
      plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
      plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
      plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
  plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD


def createPlot(inTree):
  fig = plt.figure(1, facecolor='white')
  fig.clf()
  axprops = dict(xticks=[], yticks=[])
  createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
  plotTree.totalW = float(getNumLeafs(inTree))
  plotTree.totalD = float(getTreeDepth(inTree))
  plotTree.x0ff = -0.5 / plotTree.totalW
  plotTree.y0ff = 1.0
  plotTree(inTree, (0.5, 1.0), '')
  plt.show()
if __name__ == "__main__":
  train_data, test_data, labels = load_data("dd.csv")
  data_full = train_data[:]
  labels_full = labels[:]

  mode="post"
  mode = "prev"
  mode="post"
  myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)
  createPlot(myTree)
  print(json.dumps(myTree, ensure_ascii=False, indent=4))

选择mode就可以分别得到三种树图

if __name__ == "__main__":
  train_data, test_data, labels = load_data("dd.csv")
  data_full = train_data[:]
  labels_full = labels[:]

  mode="post"
  mode = "prev"
  mode="post"
  myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)
  createPlot(myTree)
  print(json.dumps(myTree, ensure_ascii=False, indent=4))

选择mode就可以分别得到三种树图
决策树剪枝算法的python实现方法详解

决策树剪枝算法的python实现方法详解

决策树剪枝算法的python实现方法详解

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python dict remove数组删除(del,pop)
Mar 24 Python
以一段代码为实例快速入门Python2.7
Mar 31 Python
Python的shutil模块中文件的复制操作函数详解
Jul 05 Python
Python中你应该知道的一些内置函数
Mar 31 Python
Python3中的json模块使用详解
May 05 Python
通过Py2exe将自己的python程序打包成.exe/.app的方法
May 26 Python
用Python实现数据的透视表的方法
Nov 16 Python
python实现beta分布概率密度函数的方法
Jul 08 Python
python是否适合网页编程详解
Oct 04 Python
python numpy库linspace相同间隔采样的实现
Feb 25 Python
学习Python需要哪些工具
Sep 04 Python
Keras在mnist上的CNN实践,并且自定义loss函数曲线图操作
May 25 Python
python生成requirements.txt的两种方法
Sep 18 #Python
python2与python3爬虫中get与post对比解析
Sep 18 #Python
python中class的定义及使用教程
Sep 18 #Python
django创建超级用户过程解析
Sep 18 #Python
python实现网站微信登录的示例代码
Sep 18 #Python
简单了解python中的与或非运算
Sep 18 #Python
python字符串替换re.sub()方法解析
Sep 18 #Python
You might like
用PHP调用数据库的存贮过程!
2006/10/09 PHP
php excel类 phpExcel使用方法介绍
2010/08/21 PHP
PHP验证码类代码( 最新修改,完全定制化! )
2010/12/02 PHP
php调用dll的实例操作动画与代码分享
2012/08/14 PHP
php 策略模式原理与应用深入理解
2019/09/25 PHP
JAVASCRIPT  THIS详解 面向对象
2009/03/25 Javascript
jQuery 学习入门篇附实例代码
2010/03/16 Javascript
在Ubuntu系统上安装Ghost博客平台的教程
2015/06/17 Javascript
纯HTML5制作围住神经猫游戏-附源码下载
2015/08/23 Javascript
JS右下角广告窗口代码(可收缩、展开及关闭)
2015/09/04 Javascript
AngularJS ng-style中使用filter
2016/09/21 Javascript
JS多文件上传的实例代码
2017/01/11 Javascript
Javascript继承机制详解
2017/05/30 Javascript
Node.js学习之TCP/IP数据通讯(实例讲解)
2017/10/11 Javascript
nodeJS微信分享
2017/12/20 NodeJs
vue写一个组件
2018/04/09 Javascript
详解开发react应用最好用的脚手架 create-react-app
2018/04/24 Javascript
详解基于Koa2开发微信二维码扫码支付相关流程
2018/05/16 Javascript
如何通过setTimeout理解JS运行机制详解
2019/03/23 Javascript
[51:36]Optic vs Newbee 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/18 DOTA
Python开发之基于模板匹配的信用卡数字识别功能
2020/01/13 Python
Python3+selenium配置常见报错解决方案
2020/08/28 Python
python 调整图片亮度的示例
2020/12/03 Python
html5 worker 实例(二) 图片变换效果
2013/06/24 HTML / CSS
HTML5实现无刷新修改URL的方法
2019/11/14 HTML / CSS
Electrolux伊莱克斯巴西商店:家用电器、小家电和配件
2018/05/23 全球购物
美术专业学生个人自我评价
2013/09/19 职场文书
读群众路线心得体会
2014/03/07 职场文书
安全协议书范本
2014/04/21 职场文书
给校长的建议书100字
2014/05/16 职场文书
旅游与酒店管理专业求职信
2014/07/21 职场文书
个人房屋转让协议书范本
2014/10/26 职场文书
三方股东合作协议书
2014/10/28 职场文书
人间正道是沧桑观后感
2015/06/15 职场文书
详解Vue slot插槽
2021/11/20 Vue.js
MySQL七大JOIN的具体使用
2022/02/28 MySQL