决策树剪枝算法的python实现方法详解


Posted in Python onSeptember 18, 2019

本文实例讲述了决策树剪枝算法的python实现方法。分享给大家供大家参考,具体如下:

决策树是一种依托决策而建立起来的一种树。在机器学习中,决策树是一种预测模型,代表的是一种对象属性与对象值之间的一种映射关系,每一个节点代表某个对象,树中的每一个分叉路径代表某个可能的属性值,而每一个叶子节点则对应从根节点到该叶子节点所经历的路径所表示的对象的值。决策树仅有单一输出,如果有多个输出,可以分别建立独立的决策树以处理不同的输出。

ID3算法:ID3算法是决策树的一种,是基于奥卡姆剃刀原理的,即用尽量用较少的东西做更多的事。ID3算法,即Iterative Dichotomiser 3,迭代二叉树3代,是Ross Quinlan发明的一种决策树算法,这个算法的基础就是上面提到的奥卡姆剃刀原理,越是小型的决策树越优于大的决策树,尽管如此,也不总是生成最小的树型结构,而是一个启发式算法。在信息论中,期望信息越小,那么信息增益就越大,从而纯度就越高。ID3算法的核心思想就是以信息增益来度量属性的选择,选择分裂后信息增益最大的属性进行分裂。该算法采用自顶向下的贪婪搜索遍历可能的决策空间。
信息熵,将其定义为离散随机事件出现的概率,一个系统越是有序,信息熵就越低,反之一个系统越是混乱,它的信息熵就越高。所以信息熵可以被认为是系统有序化程度的一个度量。

基尼指数:在CART里面划分决策树的条件是采用Gini Index,定义如下:gini(T)=1−sumnj=1p2j。其中,( p_j )是类j在T中的相对频率,当类在T中是倾斜的时,gini(T)会最小。将T划分为T1(实例数为N1)和T2(实例数为N2)两个子集后,划分数据的Gini定义如下:ginisplit(T)=fracN1Ngini(T1)+fracN2Ngini(T2),然后选择其中最小的(gini_{split}(T) )作为结点划分决策树
具体实现
首先用函数calcShanno计算数据集的香农熵,给所有可能的分类创建字典

def calcShannonEnt(dataSet): 
  numEntries = len(dataSet) 
  labelCounts = {} 
  # 给所有可能分类创建字典 
  for featVec in dataSet: 
    currentLabel = featVec[-1] 
    if currentLabel not in labelCounts.keys(): 
      labelCounts[currentLabel] = 0
    labelCounts[currentLabel] += 1
  shannonEnt = 0.0
  # 以2为底数计算香农熵
  for key in labelCounts:
    prob = float(labelCounts[key]) / numEntries
    shannonEnt -= prob * log(prob, 2)
  return shannonEnt
# 对离散变量划分数据集,取出该特征取值为value的所有样本
def splitDataSet(dataSet, axis, value):
  retDataSet = []
  for featVec in dataSet:
    if featVec[axis] == value:
      reducedFeatVec = featVec[:axis]
      reducedFeatVec.extend(featVec[axis + 1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet

对连续变量划分数据集,direction规定划分的方向, 决定是划分出小于value的数据样本还是大于value的数据样本集

numFeatures = len(dataSet[0]) - 1
  baseEntropy = calcShannonEnt(dataSet)
  bestInfoGain = 0.0
  bestFeature = -1
  bestSplitDict = {}
  for i in range(numFeatures):
    featList = [example[i] for example in dataSet]
    # 对连续型特征进行处理
    if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int':
      # 产生n-1个候选划分点
      sortfeatList = sorted(featList)
      splitList = []
      for j in range(len(sortfeatList) - 1):
        splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)

      bestSplitEntropy = 10000
      slen = len(splitList)
      # 求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
      for j in range(slen):
        value = splitList[j]
        newEntropy = 0.0
        subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
        subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
        prob0 = len(subDataSet0) / float(len(dataSet))
        newEntropy += prob0 * calcShannonEnt(subDataSet0)
        prob1 = len(subDataSet1) / float(len(dataSet))
        newEntropy += prob1 * calcShannonEnt(subDataSet1)
        if newEntropy < bestSplitEntropy:
          bestSplitEntropy = newEntropy
          bestSplit = j
      # 用字典记录当前特征的最佳划分点
      bestSplitDict[labels[i]] = splitList[bestSplit]
      infoGain = baseEntropy - bestSplitEntropy
    # 对离散型特征进行处理
    else:
      uniqueVals = set(featList)
      newEntropy = 0.0
      # 计算该特征下每种划分的信息熵
      for value in uniqueVals:
        subDataSet = splitDataSet(dataSet, i, value)
        prob = len(subDataSet) / float(len(dataSet))
        newEntropy += prob * calcShannonEnt(subDataSet)
      infoGain = baseEntropy - newEntropy
    if infoGain > bestInfoGain:
      bestInfoGain = infoGain
      bestFeature = i
  # 若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
  # 即是否小于等于bestSplitValue
  if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int':
    bestSplitValue = bestSplitDict[labels[bestFeature]]
    labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
    for i in range(shape(dataSet)[0]):
      if dataSet[i][bestFeature] <= bestSplitValue:
        dataSet[i][bestFeature] = 1
      else:
        dataSet[i][bestFeature] = 0
  return bestFeature
def chooseBestFeatureToSplit(dataSet, labels):
  numFeatures = len(dataSet[0]) - 1
  baseEntropy = calcShannonEnt(dataSet)
  bestInfoGain = 0.0
  bestFeature = -1
  bestSplitDict = {}
  for i in range(numFeatures):
    featList = [example[i] for example in dataSet]
    # 对连续型特征进行处理
    if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int':
      # 产生n-1个候选划分点
      sortfeatList = sorted(featList)
      splitList = []
      for j in range(len(sortfeatList) - 1):
        splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)

      bestSplitEntropy = 10000
      slen = len(splitList)
      # 求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
      for j in range(slen):
        value = splitList[j]
        newEntropy = 0.0
        subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
        subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
        prob0 = len(subDataSet0) / float(len(dataSet))
        newEntropy += prob0 * calcShannonEnt(subDataSet0)
        prob1 = len(subDataSet1) / float(len(dataSet))
        newEntropy += prob1 * calcShannonEnt(subDataSet1)
        if newEntropy < bestSplitEntropy:
          bestSplitEntropy = newEntropy
          bestSplit = j
      # 用字典记录当前特征的最佳划分点
      bestSplitDict[labels[i]] = splitList[bestSplit]
      infoGain = baseEntropy - bestSplitEntropy
    # 对离散型特征进行处理
    else:
      uniqueVals = set(featList)
      newEntropy = 0.0
      # 计算该特征下每种划分的信息熵
      for value in uniqueVals:
        subDataSet = splitDataSet(dataSet, i, value)
        prob = len(subDataSet) / float(len(dataSet))
        newEntropy += prob * calcShannonEnt(subDataSet)
      infoGain = baseEntropy - newEntropy
    if infoGain > bestInfoGain:
      bestInfoGain = infoGain
      bestFeature = i
  # 若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
  # 即是否小于等于bestSplitValue
  if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int':
    bestSplitValue = bestSplitDict[labels[bestFeature]]
    labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
    for i in range(shape(dataSet)[0]):
      if dataSet[i][bestFeature] <= bestSplitValue:
        dataSet[i][bestFeature] = 1
      else:
        dataSet[i][bestFeature] = 0
  return bestFeature
``def classify(inputTree, featLabels, testVec):
  firstStr = inputTree.keys()[0]
  if u'<=' in firstStr:
    featvalue = float(firstStr.split(u"<=")[1])
    featkey = firstStr.split(u"<=")[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(featkey)
    if testVec[featIndex] <= featvalue:
      judge = 1
    else:
      judge = 0
    for key in secondDict.keys():
      if judge == int(key):
        if type(secondDict[key]).__name__ == 'dict':
          classLabel = classify(secondDict[key], featLabels, testVec)
        else:
          classLabel = secondDict[key]
  else:
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
      if testVec[featIndex] == key:
        if type(secondDict[key]).__name__ == 'dict':
          classLabel = classify(secondDict[key], featLabels, testVec)
        else:
          classLabel = secondDict[key]
  return classLabel
def majorityCnt(classList):
  classCount={}
  for vote in classList:
    if vote not in classCount.keys():
      classCount[vote]=0
    classCount[vote]+=1
  return max(classCount)
def testing_feat(feat, train_data, test_data, labels):
  class_list = [example[-1] for example in train_data]
  bestFeatIndex = labels.index(feat)
  train_data = [example[bestFeatIndex] for example in train_data]
  test_data = [(example[bestFeatIndex], example[-1]) for example in test_data]
  all_feat = set(train_data)
  error = 0.0
  for value in all_feat:
    class_feat = [class_list[i] for i in range(len(class_list)) if train_data[i] == value]
    major = majorityCnt(class_feat)
    for data in test_data:
      if data[0] == value and data[1] != major:
        error += 1.0
  # print 'myTree %d' % error
  return error

测试

error = 0.0
  for i in range(len(data_test)):
    if classify(myTree, labels, data_test[i]) != data_test[i][-1]:
      error += 1
  # print 'myTree %d' % error
  return float(error)
def testingMajor(major, data_test):
  error = 0.0
  for i in range(len(data_test)):
    if major != data_test[i][-1]:
      error += 1
  # print 'major %d' % error
  return float(error)

**递归产生决策树**

```def createTree(dataSet,labels,data_full,labels_full,test_data,mode):
  classList=[example[-1] for example in dataSet]
  if classList.count(classList[0])==len(classList):
    return classList[0]
  if len(dataSet[0])==1:
    return majorityCnt(classList)
  labels_copy = copy.deepcopy(labels)
  bestFeat=chooseBestFeatureToSplit(dataSet,labels)
  bestFeatLabel=labels[bestFeat]

  if mode == "unpro" or mode == "post":
    myTree = {bestFeatLabel: {}}
  elif mode == "prev":
    if testing_feat(bestFeatLabel, dataSet, test_data, labels_copy) < testingMajor(majorityCnt(classList),
                                            test_data):
      myTree = {bestFeatLabel: {}}
    else:
      return majorityCnt(classList)
  featValues=[example[bestFeat] for example in dataSet]
  uniqueVals=set(featValues)

  if type(dataSet[0][bestFeat]).__name__ == 'unicode':
    currentlabel = labels_full.index(labels[bestFeat])
    featValuesFull = [example[currentlabel] for example in data_full]
    uniqueValsFull = set(featValuesFull)

  del (labels[bestFeat])

  for value in uniqueVals:
    subLabels = labels[:]
    if type(dataSet[0][bestFeat]).__name__ == 'unicode':
      uniqueValsFull.remove(value)

    myTree[bestFeatLabel][value] = createTree(splitDataSet \
                           (dataSet, bestFeat, value), subLabels, data_full, labels_full,
                         splitDataSet \
                           (test_data, bestFeat, value), mode=mode)
  if type(dataSet[0][bestFeat]).__name__ == 'unicode':
    for value in uniqueValsFull:
      myTree[bestFeatLabel][value] = majorityCnt(classList)

  if mode == "post":
    if testing(myTree, test_data, labels_copy) > testingMajor(majorityCnt(classList), test_data):
      return majorityCnt(classList)
  return myTree








<div class="se-preview-section-delimiter"></div>

```**读入数据**

```def load_data(file_name):
  with open(r"dd.csv", 'rb') as f:
   df = pd.read_csv(f,sep=",")
   print(df)
   train_data = df.values[:11, 1:].tolist()
  print(train_data)
  test_data = df.values[11:, 1:].tolist()
  labels = df.columns.values[1:-1].tolist()
  return train_data, test_data, labels





<div class="se-preview-section-delimiter"></div>

```测试并绘制树图
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="round4", color='red') # 定义判断结点形态
leafNode = dict(boxstyle="circle", color='grey') # 定义叶结点形态
arrow_args = dict(arrowstyle="<-", color='blue') # 定义箭头


# 计算树的叶子节点数量
def getNumLeafs(myTree):
  numLeafs = 0
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  secondDict = myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      numLeafs += getNumLeafs(secondDict[key])
    else:
      numLeafs += 1
  return numLeafs


# 计算树的最大深度
def getTreeDepth(myTree):
  maxDepth = 0
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  secondDict = myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      thisDepth = 1 + getTreeDepth(secondDict[key])
    else:
      thisDepth = 1
    if thisDepth > maxDepth:
      maxDepth = thisDepth
  return maxDepth


# 画节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
  createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
              xytext=centerPt, textcoords='axes fraction', va="center", ha="center", \
              bbox=nodeType, arrowprops=arrow_args)


# 画箭头上的文字
def plotMidText(cntrPt, parentPt, txtString):
  lens = len(txtString)
  xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002
  yMid = (parentPt[1] + cntrPt[1]) / 2.0
  createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
  numLeafs = getNumLeafs(myTree)
  depth = getTreeDepth(myTree)
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
  plotMidText(cntrPt, parentPt, nodeTxt)
  plotNode(firstStr, cntrPt, parentPt, decisionNode)
  secondDict = myTree[firstStr]
  plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      plotTree(secondDict[key], cntrPt, str(key))
    else:
      plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
      plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
      plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
  plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD


def createPlot(inTree):
  fig = plt.figure(1, facecolor='white')
  fig.clf()
  axprops = dict(xticks=[], yticks=[])
  createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
  plotTree.totalW = float(getNumLeafs(inTree))
  plotTree.totalD = float(getTreeDepth(inTree))
  plotTree.x0ff = -0.5 / plotTree.totalW
  plotTree.y0ff = 1.0
  plotTree(inTree, (0.5, 1.0), '')
  plt.show()
if __name__ == "__main__":
  train_data, test_data, labels = load_data("dd.csv")
  data_full = train_data[:]
  labels_full = labels[:]

  mode="post"
  mode = "prev"
  mode="post"
  myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)
  createPlot(myTree)
  print(json.dumps(myTree, ensure_ascii=False, indent=4))

选择mode就可以分别得到三种树图

if __name__ == "__main__":
  train_data, test_data, labels = load_data("dd.csv")
  data_full = train_data[:]
  labels_full = labels[:]

  mode="post"
  mode = "prev"
  mode="post"
  myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)
  createPlot(myTree)
  print(json.dumps(myTree, ensure_ascii=False, indent=4))

选择mode就可以分别得到三种树图
决策树剪枝算法的python实现方法详解

决策树剪枝算法的python实现方法详解

决策树剪枝算法的python实现方法详解

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python库urllib与urllib2主要区别分析
Jul 13 Python
python使用urllib2实现发送带cookie的请求
Apr 28 Python
Python求两个文本文件以行为单位的交集、并集与差集的方法
Jun 17 Python
Go语言基于Socket编写服务器端与客户端通信的实例
Feb 19 Python
解决django前后端分离csrf验证的问题
Feb 03 Python
详解Python中的内建函数,可迭代对象,迭代器
Apr 29 Python
Python实现序列化及csv文件读取
Jan 19 Python
利用python3筛选excel中特定的行(行值满足某个条件/行值属于某个集合)
Sep 04 Python
Python类的继承super相关原理解析
Oct 22 Python
jupyter使用自动补全和切换默认浏览器的方法
Nov 18 Python
Python list去重且保持原顺序不变的方法
Apr 03 Python
如何在C++中调用Python
May 21 Python
python生成requirements.txt的两种方法
Sep 18 #Python
python2与python3爬虫中get与post对比解析
Sep 18 #Python
python中class的定义及使用教程
Sep 18 #Python
django创建超级用户过程解析
Sep 18 #Python
python实现网站微信登录的示例代码
Sep 18 #Python
简单了解python中的与或非运算
Sep 18 #Python
python字符串替换re.sub()方法解析
Sep 18 #Python
You might like
2019十大人气国漫
2020/03/13 国漫
php 图片上添加透明度渐变的效果
2009/06/29 PHP
php 代码优化之经典示例
2011/03/24 PHP
Linux下php5.4启动脚本
2014/08/03 PHP
php上传文件常见问题总结
2015/02/03 PHP
初探jquery——表单应用范例
2007/02/20 Javascript
使用prototype.js 的时候应该特别注意的几个问题.
2007/04/12 Javascript
JS高级笔记
2011/07/13 Javascript
五段实用的js高级技巧
2011/12/20 Javascript
jQuery中closest()函数用法实例
2015/01/07 Javascript
原生态js,鼠标按下后,经过了那些单元格的简单实例
2016/08/11 Javascript
Bootstrap框架实现广告轮播效果
2016/11/28 Javascript
jQuery网页定位导航特效实现方法
2016/12/19 Javascript
js实现图片360度旋转
2017/01/22 Javascript
微信小程序 rich-text的使用方法
2017/08/04 Javascript
简单实现jQuery上传图片显示预览功能
2020/06/29 jQuery
基于jQuery实现图片推拉门动画效果的两种方法
2017/08/26 jQuery
信息滚动效果的实例讲解
2017/09/18 Javascript
Angular4的输入属性与输出属性实例详解
2017/11/29 Javascript
讲解vue-router之什么是嵌套路由
2018/05/28 Javascript
vue打包之后生成一个配置文件修改接口的方法
2018/12/09 Javascript
使用Vue Composition API写出清晰、可扩展的表单实现
2020/06/10 Javascript
Vue左滑组件slider使用详解
2020/08/21 Javascript
Python基于identicon库创建类似Github上用的头像功能
2017/09/25 Python
详解Python判定IP地址合法性的三种方法
2018/03/06 Python
python 实现得到当前时间偏移day天后的日期方法
2018/12/31 Python
在OpenCV里使用Camshift算法的实现
2019/11/22 Python
Django app配置多个数据库代码实例
2019/12/17 Python
mac使用python识别图形验证码功能
2020/01/10 Python
联想墨西哥官方网站:Lenovo墨西哥
2016/08/17 全球购物
旧时光糖果:Old Time Candy
2018/02/05 全球购物
彪马荷兰官网:PUMA荷兰
2019/05/08 全球购物
实习生自荐信范文
2013/11/13 职场文书
群众路线四风自我剖析材料
2014/10/08 职场文书
5.12护士节活动总结
2015/02/10 职场文书
会计求职自荐信
2015/03/26 职场文书