Python机器学习之决策树算法实例详解


Posted in Python onDecember 06, 2017

本文实例讲述了Python机器学习之决策树算法。分享给大家供大家参考,具体如下:

决策树学习是应用最广泛的归纳推理算法之一,是一种逼近离散值目标函数的方法,在这种方法中学习到的函数被表示为一棵决策树。决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,机器学习算法最终将使用这些从数据集中创造的规则。决策树的优点为:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。缺点为:可能产生过度匹配的问题。决策树适于处理离散型和连续型的数据。

在决策树中最重要的就是如何选取用于划分的特征

在算法中一般选用ID3,D3算法的核心问题是选取在树的每个节点要测试的特征或者属性,希望选择的是最有助于分类实例的属性。如何定量地衡量一个属性的价值呢?这里需要引入熵和信息增益的概念。熵是信息论中广泛使用的一个度量标准,刻画了任意样本集的纯度。

假设有10个训练样本,其中6个的分类标签为yes,4个的分类标签为no,那熵是多少呢?在该例子中,分类的数目为2(yes,no),yes的概率为0.6,no的概率为0.4,则熵为 :

Python机器学习之决策树算法实例详解

Python机器学习之决策树算法实例详解

其中value(A)是属性A所有可能值的集合,Python机器学习之决策树算法实例详解是S中属性A的值为v的子集,即Python机器学习之决策树算法实例详解。上述公式的第一项为原集合S的熵,第二项是用A分类S后熵的期望值,该项描述的期望熵就是每个子集的熵的加权和,权值为属于的样本占原始样本S的比例Python机器学习之决策树算法实例详解。所以Gain(S, A)是由于知道属性A的值而导致的期望熵减少。

完整的代码:

# -*- coding: cp936 -*-
from numpy import *
import operator
from math import log
import operator
def createDataSet():
  dataSet = [[1,1,'yes'],
    [1,1,'yes'],
    [1,0,'no'],
    [0,1,'no'],
    [0,1,'no']]
  labels = ['no surfacing','flippers']
  return dataSet, labels
def calcShannonEnt(dataSet):
  numEntries = len(dataSet)
  labelCounts = {} # a dictionary for feature
  for featVec in dataSet:
    currentLabel = featVec[-1]
    if currentLabel not in labelCounts.keys():
      labelCounts[currentLabel] = 0
    labelCounts[currentLabel] += 1
  shannonEnt = 0.0
  for key in labelCounts:
    #print(key)
    #print(labelCounts[key])
    prob = float(labelCounts[key])/numEntries
    #print(prob)
    shannonEnt -= prob * log(prob,2)
  return shannonEnt
#按照给定的特征划分数据集
#根据axis等于value的特征将数据提出
def splitDataSet(dataSet, axis, value):
  retDataSet = []
  for featVec in dataSet:
    if featVec[axis] == value:
      reducedFeatVec = featVec[:axis]
      reducedFeatVec.extend(featVec[axis+1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet
#选取特征,划分数据集,计算得出最好的划分数据集的特征
def chooseBestFeatureToSplit(dataSet):
  numFeatures = len(dataSet[0]) - 1 #剩下的是特征的个数
  baseEntropy = calcShannonEnt(dataSet)#计算数据集的熵,放到baseEntropy中
  bestInfoGain = 0.0;bestFeature = -1 #初始化熵增益
  for i in range(numFeatures):
    featList = [example[i] for example in dataSet] #featList存储对应特征所有可能得取值
    uniqueVals = set(featList)
    newEntropy = 0.0
    for value in uniqueVals:#下面是计算每种划分方式的信息熵,特征i个,每个特征value个值
      subDataSet = splitDataSet(dataSet, i ,value)
      prob = len(subDataSet)/float(len(dataSet)) #特征样本在总样本中的权重
      newEntropy = prob * calcShannonEnt(subDataSet)
    infoGain = baseEntropy - newEntropy #计算i个特征的信息熵
    #print(i)
    #print(infoGain)
    if(infoGain > bestInfoGain):
      bestInfoGain = infoGain
      bestFeature = i
  return bestFeature
#如上面是决策树所有的功能模块
#得到原始数据集之后基于最好的属性值进行划分,每一次划分之后传递到树分支的下一个节点
#递归结束的条件是程序遍历完成所有的数据集属性,或者是每一个分支下的所有实例都具有相同的分类
#如果所有实例具有相同的分类,则得到一个叶子节点或者终止快
#如果所有属性都已经被处理,但是类标签依然不是确定的,那么采用多数投票的方式
#返回出现次数最多的分类名称
def majorityCnt(classList):
  classCount = {}
  for vote in classList:
    if vote not in classCount.keys():classCount[vote] = 0
    classCount[vote] += 1
  sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True)
  return sortedClassCount[0][0]
#创建决策树
def createTree(dataSet,labels):
  classList = [example[-1] for example in dataSet]#将最后一行的数据放到classList中,所有的类别的值
  if classList.count(classList[0]) == len(classList): #类别完全相同不需要再划分
    return classList[0]
  if len(dataSet[0]) == 1:#这里为什么是1呢?就是说特征数为1的时候
    return majorityCnt(classList)#就返回这个特征就行了,因为就这一个特征
  bestFeat = chooseBestFeatureToSplit(dataSet)
  print('the bestFeatue in creating is :')
  print(bestFeat)
  bestFeatLabel = labels[bestFeat]#运行结果'no surfacing'
  myTree = {bestFeatLabel:{}}#嵌套字典,目前value是一个空字典
  del(labels[bestFeat])
  featValues = [example[bestFeat] for example in dataSet]#第0个特征对应的取值
  uniqueVals = set(featValues)
  for value in uniqueVals: #根据当前特征值的取值进行下一级的划分
    subLabels = labels[:]
    myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
  return myTree
#对上面简单的数据进行小测试
def testTree1():
  myDat,labels=createDataSet()
  val = calcShannonEnt(myDat)
  print 'The classify accuracy is: %.2f%%' % val
  retDataSet1 = splitDataSet(myDat,0,1)
  print (myDat)
  print(retDataSet1)
  retDataSet0 = splitDataSet(myDat,0,0)
  print (myDat)
  print(retDataSet0)
  bestfeature = chooseBestFeatureToSplit(myDat)
  print('the bestFeatue is :')
  print(bestfeature)
  tree = createTree(myDat,labels)
  print(tree)

对应的结果是:

>>> import TREE
>>> TREE.testTree1()
The classify accuracy is: 0.97%
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'yes'], [1, 'yes'], [0, 'no']]
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'no'], [1, 'no']]
the bestFeatue is :
0
the bestFeatue in creating is :
0
the bestFeatue in creating is :
0
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

最好再增加使用决策树的分类函数

同时因为构建决策树是非常耗时间的,因为最好是将构建好的树通过 python 的 pickle 序列化对象,将对象保存在磁盘上,等到需要用的时候再读出

def classify(inputTree,featLabels,testVec):
  firstStr = inputTree.keys()[0]
  secondDict = inputTree[firstStr]
  featIndex = featLabels.index(firstStr)
  key = testVec[featIndex]
  valueOfFeat = secondDict[key]
  if isinstance(valueOfFeat, dict):
    classLabel = classify(valueOfFeat, featLabels, testVec)
  else: classLabel = valueOfFeat
  return classLabel
def storeTree(inputTree,filename):
  import pickle
  fw = open(filename,'w')
  pickle.dump(inputTree,fw)
  fw.close()
def grabTree(filename):
  import pickle
  fr = open(filename)
  return pickle.load(fr)

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
跟老齐学Python之编写类之一创建实例
Oct 11 Python
介绍Python中的一些高级编程技巧
Apr 02 Python
python实现写数字文件名的递增保存文件方法
Oct 25 Python
详解Python3中ceil()函数用法
Feb 19 Python
python scatter散点图用循环分类法加图例
Mar 19 Python
pybind11和numpy进行交互的方法
Jul 04 Python
Django框架 信号调度原理解析
Sep 04 Python
在django中使用post方法时,需要增加csrftoken的例子
Mar 13 Python
Python select及selectors模块概念用法详解
Jun 22 Python
简单了解python关键字global nonlocal区别
Sep 21 Python
Pytorch 统计模型参数量的操作 param.numel()
May 13 Python
Pandas自定义选项option设置
Jul 25 Python
快速入门python学习笔记
Dec 06 #Python
Python中django学习心得
Dec 06 #Python
Python标准库inspect的具体使用方法
Dec 06 #Python
读取本地json文件,解析json(实例讲解)
Dec 06 #Python
Python语言描述最大连续子序列和
Dec 05 #Python
python matplotlib坐标轴设置的方法
Dec 05 #Python
详解K-means算法在Python中的实现
Dec 05 #Python
You might like
在Linux系统的服务器上隐藏PHP版本号的方法
2015/06/06 PHP
浅谈PHP中new self()和new static()的区别
2017/08/11 PHP
JS getStyle获取最终样式函数代码
2010/04/01 Javascript
十个迅速提升JQuery性能让你的JQuery跑得更快
2012/12/10 Javascript
关于JS管理作用域的问题
2013/04/10 Javascript
javasctipt如何显示几分钟前、几天前等
2014/04/30 Javascript
js判断变量初始化的三种形式及推荐用的形式
2014/07/22 Javascript
JavaScript之Object类型介绍
2015/04/01 Javascript
解读Bootstrap v4 sass设计
2016/05/29 Javascript
javascript实现动态显示颜色块的报表效果
2017/04/10 Javascript
vue 打包后的文件部署到express服务器上的方法
2017/08/09 Javascript
详解vue-cli 本地开发mock数据使用方法
2018/05/29 Javascript
在vue项目中引用Iview的方法
2018/09/14 Javascript
详解Next.js页面渲染的优化方案
2019/01/27 Javascript
vue单文件组件无法获取$refs的问题
2020/06/24 Javascript
nuxt静态部署打包相对路径操作
2020/11/06 Javascript
Python解析并读取PDF文件内容的方法
2018/05/08 Python
numpy.std() 计算矩阵标准差的方法
2018/07/11 Python
Python统计纯文本文件中英文单词出现个数的方法总结【测试可用】
2018/07/25 Python
python中的json总结
2018/10/11 Python
python读取xlsx的方法
2018/12/25 Python
解决pycharm工程启动卡住没反应的问题
2019/01/19 Python
Python安装Flask环境及简单应用示例
2019/05/03 Python
keras导入weights方式
2020/06/12 Python
python 读取.nii格式图像实例
2020/07/01 Python
有关HTML5 Video对象的ontimeupdate事件(Chrome上无效)的问题
2013/07/19 HTML / CSS
html5组织内容_动力节点Java学院整理
2017/07/10 HTML / CSS
Tessabit美国:集世界奢侈品和设计师品牌的意大利精品买手店
2020/06/29 全球购物
大型活动组织方案
2014/05/10 职场文书
教师职位说明书
2014/07/29 职场文书
机械设备与数控技术专业求职信
2014/08/10 职场文书
房屋租赁授权委托书范本
2014/09/20 职场文书
《7的乘法口诀》教学反思
2016/02/18 职场文书
工人先锋号事迹材料(2016精选版)
2016/03/01 职场文书
python3操作redis实现List列表实例
2021/08/04 Python
mysql如何能有效防止删库跑路
2021/10/05 MySQL