决策树剪枝算法的python实现方法详解


Posted in Python onSeptember 18, 2019

本文实例讲述了决策树剪枝算法的python实现方法。分享给大家供大家参考,具体如下:

决策树是一种依托决策而建立起来的一种树。在机器学习中,决策树是一种预测模型,代表的是一种对象属性与对象值之间的一种映射关系,每一个节点代表某个对象,树中的每一个分叉路径代表某个可能的属性值,而每一个叶子节点则对应从根节点到该叶子节点所经历的路径所表示的对象的值。决策树仅有单一输出,如果有多个输出,可以分别建立独立的决策树以处理不同的输出。

ID3算法:ID3算法是决策树的一种,是基于奥卡姆剃刀原理的,即用尽量用较少的东西做更多的事。ID3算法,即Iterative Dichotomiser 3,迭代二叉树3代,是Ross Quinlan发明的一种决策树算法,这个算法的基础就是上面提到的奥卡姆剃刀原理,越是小型的决策树越优于大的决策树,尽管如此,也不总是生成最小的树型结构,而是一个启发式算法。在信息论中,期望信息越小,那么信息增益就越大,从而纯度就越高。ID3算法的核心思想就是以信息增益来度量属性的选择,选择分裂后信息增益最大的属性进行分裂。该算法采用自顶向下的贪婪搜索遍历可能的决策空间。
信息熵,将其定义为离散随机事件出现的概率,一个系统越是有序,信息熵就越低,反之一个系统越是混乱,它的信息熵就越高。所以信息熵可以被认为是系统有序化程度的一个度量。

基尼指数:在CART里面划分决策树的条件是采用Gini Index,定义如下:gini(T)=1−sumnj=1p2j。其中,( p_j )是类j在T中的相对频率,当类在T中是倾斜的时,gini(T)会最小。将T划分为T1(实例数为N1)和T2(实例数为N2)两个子集后,划分数据的Gini定义如下:ginisplit(T)=fracN1Ngini(T1)+fracN2Ngini(T2),然后选择其中最小的(gini_{split}(T) )作为结点划分决策树
具体实现
首先用函数calcShanno计算数据集的香农熵,给所有可能的分类创建字典

def calcShannonEnt(dataSet): 
  numEntries = len(dataSet) 
  labelCounts = {} 
  # 给所有可能分类创建字典 
  for featVec in dataSet: 
    currentLabel = featVec[-1] 
    if currentLabel not in labelCounts.keys(): 
      labelCounts[currentLabel] = 0
    labelCounts[currentLabel] += 1
  shannonEnt = 0.0
  # 以2为底数计算香农熵
  for key in labelCounts:
    prob = float(labelCounts[key]) / numEntries
    shannonEnt -= prob * log(prob, 2)
  return shannonEnt
# 对离散变量划分数据集,取出该特征取值为value的所有样本
def splitDataSet(dataSet, axis, value):
  retDataSet = []
  for featVec in dataSet:
    if featVec[axis] == value:
      reducedFeatVec = featVec[:axis]
      reducedFeatVec.extend(featVec[axis + 1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet

对连续变量划分数据集,direction规定划分的方向, 决定是划分出小于value的数据样本还是大于value的数据样本集

numFeatures = len(dataSet[0]) - 1
  baseEntropy = calcShannonEnt(dataSet)
  bestInfoGain = 0.0
  bestFeature = -1
  bestSplitDict = {}
  for i in range(numFeatures):
    featList = [example[i] for example in dataSet]
    # 对连续型特征进行处理
    if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int':
      # 产生n-1个候选划分点
      sortfeatList = sorted(featList)
      splitList = []
      for j in range(len(sortfeatList) - 1):
        splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)

      bestSplitEntropy = 10000
      slen = len(splitList)
      # 求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
      for j in range(slen):
        value = splitList[j]
        newEntropy = 0.0
        subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
        subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
        prob0 = len(subDataSet0) / float(len(dataSet))
        newEntropy += prob0 * calcShannonEnt(subDataSet0)
        prob1 = len(subDataSet1) / float(len(dataSet))
        newEntropy += prob1 * calcShannonEnt(subDataSet1)
        if newEntropy < bestSplitEntropy:
          bestSplitEntropy = newEntropy
          bestSplit = j
      # 用字典记录当前特征的最佳划分点
      bestSplitDict[labels[i]] = splitList[bestSplit]
      infoGain = baseEntropy - bestSplitEntropy
    # 对离散型特征进行处理
    else:
      uniqueVals = set(featList)
      newEntropy = 0.0
      # 计算该特征下每种划分的信息熵
      for value in uniqueVals:
        subDataSet = splitDataSet(dataSet, i, value)
        prob = len(subDataSet) / float(len(dataSet))
        newEntropy += prob * calcShannonEnt(subDataSet)
      infoGain = baseEntropy - newEntropy
    if infoGain > bestInfoGain:
      bestInfoGain = infoGain
      bestFeature = i
  # 若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
  # 即是否小于等于bestSplitValue
  if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int':
    bestSplitValue = bestSplitDict[labels[bestFeature]]
    labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
    for i in range(shape(dataSet)[0]):
      if dataSet[i][bestFeature] <= bestSplitValue:
        dataSet[i][bestFeature] = 1
      else:
        dataSet[i][bestFeature] = 0
  return bestFeature
def chooseBestFeatureToSplit(dataSet, labels):
  numFeatures = len(dataSet[0]) - 1
  baseEntropy = calcShannonEnt(dataSet)
  bestInfoGain = 0.0
  bestFeature = -1
  bestSplitDict = {}
  for i in range(numFeatures):
    featList = [example[i] for example in dataSet]
    # 对连续型特征进行处理
    if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int':
      # 产生n-1个候选划分点
      sortfeatList = sorted(featList)
      splitList = []
      for j in range(len(sortfeatList) - 1):
        splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)

      bestSplitEntropy = 10000
      slen = len(splitList)
      # 求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
      for j in range(slen):
        value = splitList[j]
        newEntropy = 0.0
        subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
        subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
        prob0 = len(subDataSet0) / float(len(dataSet))
        newEntropy += prob0 * calcShannonEnt(subDataSet0)
        prob1 = len(subDataSet1) / float(len(dataSet))
        newEntropy += prob1 * calcShannonEnt(subDataSet1)
        if newEntropy < bestSplitEntropy:
          bestSplitEntropy = newEntropy
          bestSplit = j
      # 用字典记录当前特征的最佳划分点
      bestSplitDict[labels[i]] = splitList[bestSplit]
      infoGain = baseEntropy - bestSplitEntropy
    # 对离散型特征进行处理
    else:
      uniqueVals = set(featList)
      newEntropy = 0.0
      # 计算该特征下每种划分的信息熵
      for value in uniqueVals:
        subDataSet = splitDataSet(dataSet, i, value)
        prob = len(subDataSet) / float(len(dataSet))
        newEntropy += prob * calcShannonEnt(subDataSet)
      infoGain = baseEntropy - newEntropy
    if infoGain > bestInfoGain:
      bestInfoGain = infoGain
      bestFeature = i
  # 若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
  # 即是否小于等于bestSplitValue
  if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int':
    bestSplitValue = bestSplitDict[labels[bestFeature]]
    labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
    for i in range(shape(dataSet)[0]):
      if dataSet[i][bestFeature] <= bestSplitValue:
        dataSet[i][bestFeature] = 1
      else:
        dataSet[i][bestFeature] = 0
  return bestFeature
``def classify(inputTree, featLabels, testVec):
  firstStr = inputTree.keys()[0]
  if u'<=' in firstStr:
    featvalue = float(firstStr.split(u"<=")[1])
    featkey = firstStr.split(u"<=")[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(featkey)
    if testVec[featIndex] <= featvalue:
      judge = 1
    else:
      judge = 0
    for key in secondDict.keys():
      if judge == int(key):
        if type(secondDict[key]).__name__ == 'dict':
          classLabel = classify(secondDict[key], featLabels, testVec)
        else:
          classLabel = secondDict[key]
  else:
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
      if testVec[featIndex] == key:
        if type(secondDict[key]).__name__ == 'dict':
          classLabel = classify(secondDict[key], featLabels, testVec)
        else:
          classLabel = secondDict[key]
  return classLabel
def majorityCnt(classList):
  classCount={}
  for vote in classList:
    if vote not in classCount.keys():
      classCount[vote]=0
    classCount[vote]+=1
  return max(classCount)
def testing_feat(feat, train_data, test_data, labels):
  class_list = [example[-1] for example in train_data]
  bestFeatIndex = labels.index(feat)
  train_data = [example[bestFeatIndex] for example in train_data]
  test_data = [(example[bestFeatIndex], example[-1]) for example in test_data]
  all_feat = set(train_data)
  error = 0.0
  for value in all_feat:
    class_feat = [class_list[i] for i in range(len(class_list)) if train_data[i] == value]
    major = majorityCnt(class_feat)
    for data in test_data:
      if data[0] == value and data[1] != major:
        error += 1.0
  # print 'myTree %d' % error
  return error

测试

error = 0.0
  for i in range(len(data_test)):
    if classify(myTree, labels, data_test[i]) != data_test[i][-1]:
      error += 1
  # print 'myTree %d' % error
  return float(error)
def testingMajor(major, data_test):
  error = 0.0
  for i in range(len(data_test)):
    if major != data_test[i][-1]:
      error += 1
  # print 'major %d' % error
  return float(error)

**递归产生决策树**

```def createTree(dataSet,labels,data_full,labels_full,test_data,mode):
  classList=[example[-1] for example in dataSet]
  if classList.count(classList[0])==len(classList):
    return classList[0]
  if len(dataSet[0])==1:
    return majorityCnt(classList)
  labels_copy = copy.deepcopy(labels)
  bestFeat=chooseBestFeatureToSplit(dataSet,labels)
  bestFeatLabel=labels[bestFeat]

  if mode == "unpro" or mode == "post":
    myTree = {bestFeatLabel: {}}
  elif mode == "prev":
    if testing_feat(bestFeatLabel, dataSet, test_data, labels_copy) < testingMajor(majorityCnt(classList),
                                            test_data):
      myTree = {bestFeatLabel: {}}
    else:
      return majorityCnt(classList)
  featValues=[example[bestFeat] for example in dataSet]
  uniqueVals=set(featValues)

  if type(dataSet[0][bestFeat]).__name__ == 'unicode':
    currentlabel = labels_full.index(labels[bestFeat])
    featValuesFull = [example[currentlabel] for example in data_full]
    uniqueValsFull = set(featValuesFull)

  del (labels[bestFeat])

  for value in uniqueVals:
    subLabels = labels[:]
    if type(dataSet[0][bestFeat]).__name__ == 'unicode':
      uniqueValsFull.remove(value)

    myTree[bestFeatLabel][value] = createTree(splitDataSet \
                           (dataSet, bestFeat, value), subLabels, data_full, labels_full,
                         splitDataSet \
                           (test_data, bestFeat, value), mode=mode)
  if type(dataSet[0][bestFeat]).__name__ == 'unicode':
    for value in uniqueValsFull:
      myTree[bestFeatLabel][value] = majorityCnt(classList)

  if mode == "post":
    if testing(myTree, test_data, labels_copy) > testingMajor(majorityCnt(classList), test_data):
      return majorityCnt(classList)
  return myTree








<div class="se-preview-section-delimiter"></div>

```**读入数据**

```def load_data(file_name):
  with open(r"dd.csv", 'rb') as f:
   df = pd.read_csv(f,sep=",")
   print(df)
   train_data = df.values[:11, 1:].tolist()
  print(train_data)
  test_data = df.values[11:, 1:].tolist()
  labels = df.columns.values[1:-1].tolist()
  return train_data, test_data, labels





<div class="se-preview-section-delimiter"></div>

```测试并绘制树图
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="round4", color='red') # 定义判断结点形态
leafNode = dict(boxstyle="circle", color='grey') # 定义叶结点形态
arrow_args = dict(arrowstyle="<-", color='blue') # 定义箭头


# 计算树的叶子节点数量
def getNumLeafs(myTree):
  numLeafs = 0
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  secondDict = myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      numLeafs += getNumLeafs(secondDict[key])
    else:
      numLeafs += 1
  return numLeafs


# 计算树的最大深度
def getTreeDepth(myTree):
  maxDepth = 0
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  secondDict = myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      thisDepth = 1 + getTreeDepth(secondDict[key])
    else:
      thisDepth = 1
    if thisDepth > maxDepth:
      maxDepth = thisDepth
  return maxDepth


# 画节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
  createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
              xytext=centerPt, textcoords='axes fraction', va="center", ha="center", \
              bbox=nodeType, arrowprops=arrow_args)


# 画箭头上的文字
def plotMidText(cntrPt, parentPt, txtString):
  lens = len(txtString)
  xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002
  yMid = (parentPt[1] + cntrPt[1]) / 2.0
  createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
  numLeafs = getNumLeafs(myTree)
  depth = getTreeDepth(myTree)
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
  plotMidText(cntrPt, parentPt, nodeTxt)
  plotNode(firstStr, cntrPt, parentPt, decisionNode)
  secondDict = myTree[firstStr]
  plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      plotTree(secondDict[key], cntrPt, str(key))
    else:
      plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
      plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
      plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
  plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD


def createPlot(inTree):
  fig = plt.figure(1, facecolor='white')
  fig.clf()
  axprops = dict(xticks=[], yticks=[])
  createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
  plotTree.totalW = float(getNumLeafs(inTree))
  plotTree.totalD = float(getTreeDepth(inTree))
  plotTree.x0ff = -0.5 / plotTree.totalW
  plotTree.y0ff = 1.0
  plotTree(inTree, (0.5, 1.0), '')
  plt.show()
if __name__ == "__main__":
  train_data, test_data, labels = load_data("dd.csv")
  data_full = train_data[:]
  labels_full = labels[:]

  mode="post"
  mode = "prev"
  mode="post"
  myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)
  createPlot(myTree)
  print(json.dumps(myTree, ensure_ascii=False, indent=4))

选择mode就可以分别得到三种树图

if __name__ == "__main__":
  train_data, test_data, labels = load_data("dd.csv")
  data_full = train_data[:]
  labels_full = labels[:]

  mode="post"
  mode = "prev"
  mode="post"
  myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)
  createPlot(myTree)
  print(json.dumps(myTree, ensure_ascii=False, indent=4))

选择mode就可以分别得到三种树图
决策树剪枝算法的python实现方法详解

决策树剪枝算法的python实现方法详解

决策树剪枝算法的python实现方法详解

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
遍历python字典几种方法总结(推荐)
Sep 11 Python
python删除过期log文件操作实例解析
Jan 31 Python
Django自定义manage命令实例代码
Feb 11 Python
Python实现查看系统启动项功能示例
May 10 Python
Python面向对象类继承和组合实例分析
May 28 Python
python爬虫之urllib,伪装,超时设置,异常处理的方法
Dec 19 Python
Python3远程监控程序的实现方法
Jul 15 Python
Python面向对象之Web静态服务器
Sep 03 Python
python__name__原理及用法详解
Nov 02 Python
pycharm修改file type方式
Nov 19 Python
Python3爬虫发送请求的知识点实例
Jul 30 Python
python3实现名片管理系统(控制台版)
Nov 29 Python
python生成requirements.txt的两种方法
Sep 18 #Python
python2与python3爬虫中get与post对比解析
Sep 18 #Python
python中class的定义及使用教程
Sep 18 #Python
django创建超级用户过程解析
Sep 18 #Python
python实现网站微信登录的示例代码
Sep 18 #Python
简单了解python中的与或非运算
Sep 18 #Python
python字符串替换re.sub()方法解析
Sep 18 #Python
You might like
比较strtr, str_replace和preg_replace三个函数的效率
2013/06/26 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(十一)
2014/06/25 PHP
php采用curl模仿登录人人网发布动态的方法
2014/11/07 PHP
PHP编程文件处理类SplFileObject和SplFileInfo用法实例分析
2017/07/22 PHP
用 JSON 处理缓存
2007/04/27 Javascript
bgsound 背景音乐 的一些常用方法及特殊用法小结
2010/05/11 Javascript
Jquery知识点二 jquery下对数组的操作
2011/01/15 Javascript
基于jquery实现图片广告轮换效果代码
2011/07/07 Javascript
javascript中关于break,continue的特殊用法与介绍
2012/05/24 Javascript
基于JS实现的倒计时程序实例
2015/07/24 Javascript
node.js路径处理方法以及绝对路径详解
2021/03/04 Javascript
微信小程序 input表单与redio及下拉列表的使用实例
2017/09/20 Javascript
webpack公共组件引用路径简化小技巧
2018/06/15 Javascript
vue移动端监听滚动条高度的实现方法
2018/09/03 Javascript
Vue的状态管理vuex使用方法详解
2020/02/05 Javascript
Vue 组件复用多次自定义参数操作
2020/07/27 Javascript
本地文件上传到七牛云服务器示例(七牛云存储)
2014/01/11 Python
python计算文本文件行数的方法
2015/07/06 Python
Python实现针对中文排序的方法
2017/05/09 Python
python使用Pycharm创建一个Django项目
2018/03/05 Python
ubuntu16.04制作vim和python3的开发环境
2018/09/23 Python
浅谈python中真正关闭socket的方法
2018/12/18 Python
使用Python批量修改文件名的代码实例
2019/01/24 Python
python实现爬虫抓取小说功能示例【抓取金庸小说】
2019/08/09 Python
tensorflow实现在函数中用tf.Print输出中间值
2020/01/21 Python
详解css3中dispaly的Grid布局与Flex布局
2020/09/11 HTML / CSS
adidas旗下高尔夫装备供应商:TaylorMade Golf(泰勒梅高尔夫)
2016/08/28 全球购物
西班牙汉普顿小姐:购买帆布鞋和太阳镜
2016/10/23 全球购物
英国奢侈品网站:MatchesFashion
2016/12/16 全球购物
英国男士时尚购物网站:Stuarts London
2017/10/22 全球购物
汽车专业毕业生自荐信
2013/11/03 职场文书
自我评价的写作规则
2014/01/06 职场文书
《一个小村庄的故事》教学反思
2014/04/13 职场文书
2015年感恩节演讲稿(优选篇)
2015/03/20 职场文书
2015年基层党支部工作总结
2015/05/21 职场文书
实习报告范文
2019/07/30 职场文书