Python机器学习之决策树算法实例详解


Posted in Python onDecember 06, 2017

本文实例讲述了Python机器学习之决策树算法。分享给大家供大家参考,具体如下:

决策树学习是应用最广泛的归纳推理算法之一,是一种逼近离散值目标函数的方法,在这种方法中学习到的函数被表示为一棵决策树。决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,机器学习算法最终将使用这些从数据集中创造的规则。决策树的优点为:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。缺点为:可能产生过度匹配的问题。决策树适于处理离散型和连续型的数据。

在决策树中最重要的就是如何选取用于划分的特征

在算法中一般选用ID3,D3算法的核心问题是选取在树的每个节点要测试的特征或者属性,希望选择的是最有助于分类实例的属性。如何定量地衡量一个属性的价值呢?这里需要引入熵和信息增益的概念。熵是信息论中广泛使用的一个度量标准,刻画了任意样本集的纯度。

假设有10个训练样本,其中6个的分类标签为yes,4个的分类标签为no,那熵是多少呢?在该例子中,分类的数目为2(yes,no),yes的概率为0.6,no的概率为0.4,则熵为 :

Python机器学习之决策树算法实例详解

Python机器学习之决策树算法实例详解

其中value(A)是属性A所有可能值的集合,Python机器学习之决策树算法实例详解是S中属性A的值为v的子集,即Python机器学习之决策树算法实例详解。上述公式的第一项为原集合S的熵,第二项是用A分类S后熵的期望值,该项描述的期望熵就是每个子集的熵的加权和,权值为属于的样本占原始样本S的比例Python机器学习之决策树算法实例详解。所以Gain(S, A)是由于知道属性A的值而导致的期望熵减少。

完整的代码:

# -*- coding: cp936 -*-
from numpy import *
import operator
from math import log
import operator
def createDataSet():
  dataSet = [[1,1,'yes'],
    [1,1,'yes'],
    [1,0,'no'],
    [0,1,'no'],
    [0,1,'no']]
  labels = ['no surfacing','flippers']
  return dataSet, labels
def calcShannonEnt(dataSet):
  numEntries = len(dataSet)
  labelCounts = {} # a dictionary for feature
  for featVec in dataSet:
    currentLabel = featVec[-1]
    if currentLabel not in labelCounts.keys():
      labelCounts[currentLabel] = 0
    labelCounts[currentLabel] += 1
  shannonEnt = 0.0
  for key in labelCounts:
    #print(key)
    #print(labelCounts[key])
    prob = float(labelCounts[key])/numEntries
    #print(prob)
    shannonEnt -= prob * log(prob,2)
  return shannonEnt
#按照给定的特征划分数据集
#根据axis等于value的特征将数据提出
def splitDataSet(dataSet, axis, value):
  retDataSet = []
  for featVec in dataSet:
    if featVec[axis] == value:
      reducedFeatVec = featVec[:axis]
      reducedFeatVec.extend(featVec[axis+1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet
#选取特征,划分数据集,计算得出最好的划分数据集的特征
def chooseBestFeatureToSplit(dataSet):
  numFeatures = len(dataSet[0]) - 1 #剩下的是特征的个数
  baseEntropy = calcShannonEnt(dataSet)#计算数据集的熵,放到baseEntropy中
  bestInfoGain = 0.0;bestFeature = -1 #初始化熵增益
  for i in range(numFeatures):
    featList = [example[i] for example in dataSet] #featList存储对应特征所有可能得取值
    uniqueVals = set(featList)
    newEntropy = 0.0
    for value in uniqueVals:#下面是计算每种划分方式的信息熵,特征i个,每个特征value个值
      subDataSet = splitDataSet(dataSet, i ,value)
      prob = len(subDataSet)/float(len(dataSet)) #特征样本在总样本中的权重
      newEntropy = prob * calcShannonEnt(subDataSet)
    infoGain = baseEntropy - newEntropy #计算i个特征的信息熵
    #print(i)
    #print(infoGain)
    if(infoGain > bestInfoGain):
      bestInfoGain = infoGain
      bestFeature = i
  return bestFeature
#如上面是决策树所有的功能模块
#得到原始数据集之后基于最好的属性值进行划分,每一次划分之后传递到树分支的下一个节点
#递归结束的条件是程序遍历完成所有的数据集属性,或者是每一个分支下的所有实例都具有相同的分类
#如果所有实例具有相同的分类,则得到一个叶子节点或者终止快
#如果所有属性都已经被处理,但是类标签依然不是确定的,那么采用多数投票的方式
#返回出现次数最多的分类名称
def majorityCnt(classList):
  classCount = {}
  for vote in classList:
    if vote not in classCount.keys():classCount[vote] = 0
    classCount[vote] += 1
  sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True)
  return sortedClassCount[0][0]
#创建决策树
def createTree(dataSet,labels):
  classList = [example[-1] for example in dataSet]#将最后一行的数据放到classList中,所有的类别的值
  if classList.count(classList[0]) == len(classList): #类别完全相同不需要再划分
    return classList[0]
  if len(dataSet[0]) == 1:#这里为什么是1呢?就是说特征数为1的时候
    return majorityCnt(classList)#就返回这个特征就行了,因为就这一个特征
  bestFeat = chooseBestFeatureToSplit(dataSet)
  print('the bestFeatue in creating is :')
  print(bestFeat)
  bestFeatLabel = labels[bestFeat]#运行结果'no surfacing'
  myTree = {bestFeatLabel:{}}#嵌套字典,目前value是一个空字典
  del(labels[bestFeat])
  featValues = [example[bestFeat] for example in dataSet]#第0个特征对应的取值
  uniqueVals = set(featValues)
  for value in uniqueVals: #根据当前特征值的取值进行下一级的划分
    subLabels = labels[:]
    myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
  return myTree
#对上面简单的数据进行小测试
def testTree1():
  myDat,labels=createDataSet()
  val = calcShannonEnt(myDat)
  print 'The classify accuracy is: %.2f%%' % val
  retDataSet1 = splitDataSet(myDat,0,1)
  print (myDat)
  print(retDataSet1)
  retDataSet0 = splitDataSet(myDat,0,0)
  print (myDat)
  print(retDataSet0)
  bestfeature = chooseBestFeatureToSplit(myDat)
  print('the bestFeatue is :')
  print(bestfeature)
  tree = createTree(myDat,labels)
  print(tree)

对应的结果是:

>>> import TREE
>>> TREE.testTree1()
The classify accuracy is: 0.97%
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'yes'], [1, 'yes'], [0, 'no']]
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'no'], [1, 'no']]
the bestFeatue is :
0
the bestFeatue in creating is :
0
the bestFeatue in creating is :
0
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

最好再增加使用决策树的分类函数

同时因为构建决策树是非常耗时间的,因为最好是将构建好的树通过 python 的 pickle 序列化对象,将对象保存在磁盘上,等到需要用的时候再读出

def classify(inputTree,featLabels,testVec):
  firstStr = inputTree.keys()[0]
  secondDict = inputTree[firstStr]
  featIndex = featLabels.index(firstStr)
  key = testVec[featIndex]
  valueOfFeat = secondDict[key]
  if isinstance(valueOfFeat, dict):
    classLabel = classify(valueOfFeat, featLabels, testVec)
  else: classLabel = valueOfFeat
  return classLabel
def storeTree(inputTree,filename):
  import pickle
  fw = open(filename,'w')
  pickle.dump(inputTree,fw)
  fw.close()
def grabTree(filename):
  import pickle
  fr = open(filename)
  return pickle.load(fr)

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
详谈Python2.6和Python3.0中对除法操作的异同
Apr 28 Python
详解python基础之while循环及if判断
Aug 24 Python
Python基础语言学习笔记总结(精华)
Nov 14 Python
使用apidocJs快速生成在线文档的实例讲解
Feb 07 Python
python基于http下载视频或音频
Jun 20 Python
Python实现简单的用户交互方法详解
Sep 25 Python
Python FTP文件定时自动下载实现过程解析
Nov 12 Python
TensorFlow索引与切片的实现方法
Nov 20 Python
关于ResNeXt网络的pytorch实现
Jan 14 Python
GDAL 矢量属性数据修改方式(python)
Mar 10 Python
Python接口测试文件上传实例解析
May 22 Python
python 密码学示例——理解哈希(Hash)算法
Sep 21 Python
快速入门python学习笔记
Dec 06 #Python
Python中django学习心得
Dec 06 #Python
Python标准库inspect的具体使用方法
Dec 06 #Python
读取本地json文件,解析json(实例讲解)
Dec 06 #Python
Python语言描述最大连续子序列和
Dec 05 #Python
python matplotlib坐标轴设置的方法
Dec 05 #Python
详解K-means算法在Python中的实现
Dec 05 #Python
You might like
php摘要生成函数(无乱码)
2012/02/04 PHP
解决ThinkPHP下使用上传插件Uploadify浏览器firefox报302错误的方法
2015/12/18 PHP
PHP/HTML混写的四种方式总结
2017/02/27 PHP
PHP实现简单的模板引擎功能示例
2017/09/02 PHP
JS 文字符串转换unicode编码函数
2009/05/30 Javascript
jQuery 使用手册(一)
2009/09/23 Javascript
JavaScript对HTML DOM使用EventListener进行操作
2015/10/21 Javascript
第四篇Bootstrap网格系统偏移列和嵌套列
2016/06/21 Javascript
JavaScript实现刷新不重记的倒计时
2016/08/10 Javascript
js中的eval()函数把含有转义字符的字符串转换成Object对象的方法
2016/12/02 Javascript
Angularjs使用ng-repeat中$even和$odd属性的注意事项
2016/12/31 Javascript
AngularJS1.X学习笔记2-数据绑定详解
2017/04/01 Javascript
jQuery返回定位插件详解
2017/05/15 jQuery
全面解析Node.js 8 重要功能和修复
2017/06/02 Javascript
微信小程序组件 marquee实例详解
2017/06/23 Javascript
Vue波纹按钮组件制作
2018/04/30 Javascript
angular2/ionic2 实现搜索结果中的搜索关键字高亮的示例
2018/08/17 Javascript
微信小程序定义和调用全局变量globalData的实现
2019/11/01 Javascript
微信小程序组件生命周期的踩坑记录
2021/03/03 Javascript
linux系统使用python获取cpu信息脚本分享
2014/01/15 Python
python按照多个字符对字符串进行分割的方法
2015/03/17 Python
Python 正则表达式实现计算器功能
2017/04/29 Python
django DRF图片路径问题的解决方法
2018/09/10 Python
python一键去抖音视频水印工具
2018/09/14 Python
详解Python 切片语法
2019/06/10 Python
python 实现屏幕录制示例
2019/12/23 Python
HTML5 DeviceOrientation实现手机网站摇一摇功能代码实例
2015/04/24 HTML / CSS
德国鞋子网上商店:Omoda.de
2017/03/31 全球购物
泰国第一的化妆品网站:Konvy
2018/02/25 全球购物
Ibatis如何使用动态表名
2015/07/12 面试题
在浏览器端如何得到服务器端响应的XML数据
2012/11/24 面试题
年会活动策划方案
2014/01/23 职场文书
公共艺术专业自荐信
2014/09/01 职场文书
安全生产学习心得体会
2016/01/18 职场文书
JavaScript offset实现鼠标坐标获取和窗口内模块拖动
2021/05/30 Javascript
JavaScript最完整的深浅拷贝实现方式详解
2022/02/28 Javascript