python编写分类决策树的代码


Posted in Python onDecember 21, 2017

决策树通常在机器学习中用于分类。

优点:计算复杂度不高,输出结果易于理解,对中间值缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题。
适用数据类型:数值型和标称型。

1.信息增益

划分数据集的目的是:将无序的数据变得更加有序。组织杂乱无章数据的一种方法就是使用信息论度量信息。通常采用信息增益,信息增益是指数据划分前后信息熵的减少值。信息越无序信息熵越大,获得信息增益最高的特征就是最好的选择。
熵定义为信息的期望,符号xi的信息定义为:

python编写分类决策树的代码

其中p(xi)为该分类的概率。
熵,即信息的期望值为:

python编写分类决策树的代码

计算信息熵的代码如下:

def calcShannonEnt(dataSet):
  numEntries = len(dataSet)
  labelCounts = {}
  for featVec in dataSet:
    currentLabel = featVec[-1]
    if currentLabel not in labelCounts:
      labelCounts[currentLabel] = 0
    labelCounts[currentLabel] += 1
  shannonEnt = 0
  for key in labelCounts:
    shannonEnt = shannonEnt - (labelCounts[key]/numEntries)*math.log2(labelCounts[key]/numEntries)
  return shannonEnt

可以根据信息熵,按照获取最大信息增益的方法划分数据集。

2.划分数据集

划分数据集就是将所有符合要求的元素抽出来。

def splitDataSet(dataSet,axis,value):
  retDataset = []
  for featVec in dataSet:
    if featVec[axis] == value:
      newVec = featVec[:axis]
      newVec.extend(featVec[axis+1:])
      retDataset.append(newVec)
  return retDataset

3.选择最好的数据集划分方式

信息增益是熵的减少或者是信息无序度的减少。

def chooseBestFeatureToSplit(dataSet):
  numFeatures = len(dataSet[0]) - 1
  bestInfoGain = 0
  bestFeature = -1
  baseEntropy = calcShannonEnt(dataSet)
  for i in range(numFeatures):
    allValue = [example[i] for example in dataSet]#列表推倒,创建新的列表
    allValue = set(allValue)#最快得到列表中唯一元素值的方法
    newEntropy = 0
    for value in allValue:
      splitset = splitDataSet(dataSet,i,value)
      newEntropy = newEntropy + len(splitset)/len(dataSet)*calcShannonEnt(splitset)
    infoGain = baseEntropy - newEntropy
    if infoGain > bestInfoGain:
      bestInfoGain = infoGain
      bestFeature = i
  return bestFeature

4.递归创建决策树

结束条件为:程序遍历完所有划分数据集的属性,或每个分支下的所有实例都具有相同的分类。
当数据集已经处理了所有属性,但是类标签还不唯一时,采用多数表决的方式决定叶子节点的类型。

def majorityCnt(classList):
 classCount = {}
 for value in classList:
  if value not in classCount: classCount[value] = 0
  classCount[value] += 1
 classCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
 return classCount[0][0]

生成决策树:

def createTree(dataSet,labels):
 classList = [example[-1] for example in dataSet]
 labelsCopy = labels[:]
 if classList.count(classList[0]) == len(classList):
  return classList[0]
 if len(dataSet[0]) == 1:
  return majorityCnt(classList)
 bestFeature = chooseBestFeatureToSplit(dataSet)
 bestLabel = labelsCopy[bestFeature]
 myTree = {bestLabel:{}}
 featureValues = [example[bestFeature] for example in dataSet]
 featureValues = set(featureValues)
 del(labelsCopy[bestFeature])
 for value in featureValues:
  subLabels = labelsCopy[:]
  myTree[bestLabel][value] = createTree(splitDataSet(dataSet,bestFeature,value),subLabels)
 return myTree

5.测试算法——使用决策树分类

同样采用递归的方式得到分类结果。

def classify(inputTree,featLabels,testVec):
 currentFeat = list(inputTree.keys())[0]
 secondTree = inputTree[currentFeat]
 try:
  featureIndex = featLabels.index(currentFeat)
 except ValueError as err:
  print('yes')
 try:
  for value in secondTree.keys():
   if value == testVec[featureIndex]:
    if type(secondTree[value]).__name__ == 'dict':
     classLabel = classify(secondTree[value],featLabels,testVec)
    else:
     classLabel = secondTree[value]
  return classLabel
 except AttributeError:
  print(secondTree)

6.完整代码如下

import numpy as np
import math
import operator
def createDataSet():
 dataSet = [[1,1,'yes'],
    [1,1,'yes'],
    [1,0,'no'],
    [0,1,'no'],
    [0,1,'no'],]
 label = ['no surfacing','flippers']
 return dataSet,label

def calcShannonEnt(dataSet):
 numEntries = len(dataSet)
 labelCounts = {}
 for featVec in dataSet:
  currentLabel = featVec[-1]
  if currentLabel not in labelCounts:
   labelCounts[currentLabel] = 0
  labelCounts[currentLabel] += 1
 shannonEnt = 0
 for key in labelCounts:
  shannonEnt = shannonEnt - (labelCounts[key]/numEntries)*math.log2(labelCounts[key]/numEntries)
 return shannonEnt


def splitDataSet(dataSet,axis,value):
 retDataset = []
 for featVec in dataSet:
  if featVec[axis] == value:
   newVec = featVec[:axis]
   newVec.extend(featVec[axis+1:])
   retDataset.append(newVec)
 return retDataset

def chooseBestFeatureToSplit(dataSet):
 numFeatures = len(dataSet[0]) - 1
 bestInfoGain = 0
 bestFeature = -1
 baseEntropy = calcShannonEnt(dataSet)
 for i in range(numFeatures):
  allValue = [example[i] for example in dataSet]
  allValue = set(allValue)
  newEntropy = 0
  for value in allValue:
   splitset = splitDataSet(dataSet,i,value)
   newEntropy = newEntropy + len(splitset)/len(dataSet)*calcShannonEnt(splitset)
  infoGain = baseEntropy - newEntropy
  if infoGain > bestInfoGain:
   bestInfoGain = infoGain
   bestFeature = i
 return bestFeature

def majorityCnt(classList):
 classCount = {}
 for value in classList:
  if value not in classCount: classCount[value] = 0
  classCount[value] += 1
 classCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
 return classCount[0][0]   

def createTree(dataSet,labels):
 classList = [example[-1] for example in dataSet]
 labelsCopy = labels[:]
 if classList.count(classList[0]) == len(classList):
  return classList[0]
 if len(dataSet[0]) == 1:
  return majorityCnt(classList)
 bestFeature = chooseBestFeatureToSplit(dataSet)
 bestLabel = labelsCopy[bestFeature]
 myTree = {bestLabel:{}}
 featureValues = [example[bestFeature] for example in dataSet]
 featureValues = set(featureValues)
 del(labelsCopy[bestFeature])
 for value in featureValues:
  subLabels = labelsCopy[:]
  myTree[bestLabel][value] = createTree(splitDataSet(dataSet,bestFeature,value),subLabels)
 return myTree


def classify(inputTree,featLabels,testVec):
 currentFeat = list(inputTree.keys())[0]
 secondTree = inputTree[currentFeat]
 try:
  featureIndex = featLabels.index(currentFeat)
 except ValueError as err:
  print('yes')
 try:
  for value in secondTree.keys():
   if value == testVec[featureIndex]:
    if type(secondTree[value]).__name__ == 'dict':
     classLabel = classify(secondTree[value],featLabels,testVec)
    else:
     classLabel = secondTree[value]
  return classLabel
 except AttributeError:
  print(secondTree)

if __name__ == "__main__":
 dataset,label = createDataSet()
 myTree = createTree(dataset,label)
 a = [1,1]
 print(classify(myTree,label,a))

7.编程技巧

extend与append的区别

newVec.extend(featVec[axis+1:])
 retDataset.append(newVec)

extend([]),是将列表中的每个元素依次加入新列表中
append()是将括号中的内容当做一项加入到新列表中

列表推到

创建新列表的方式

allValue = [example[i] for example in dataSet]

提取列表中唯一的元素

allValue = set(allValue)

列表/元组排序,sorted()函数

classCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)

列表的复制

labelsCopy = labels[:]

代码及数据集下载:决策树

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python上传package到Pypi(代码简单)
Feb 06 Python
Python Sqlite3以字典形式返回查询结果的实现方法
Oct 03 Python
python字符串中的单双引
Feb 16 Python
Python实现一个转存纯真IP数据库的脚本分享
May 21 Python
Python2和Python3中print的用法示例总结
Oct 25 Python
python 实现A*算法的示例代码
Aug 13 Python
Python打开文件,将list、numpy数组内容写入txt文件中的方法
Oct 26 Python
Python中最好用的命令行参数解析工具(argparse)
Aug 23 Python
Python从入门到精通之环境搭建教程图解
Sep 26 Python
python中可以声明变量类型吗
Jun 18 Python
python 可视化库PyG2Plot的使用
Jan 21 Python
教你用python实现一个无界面的小型图书管理系统
May 21 Python
Python基于PyGraphics包实现图片截取功能的方法
Dec 21 #Python
用Python写王者荣耀刷金币脚本
Dec 21 #Python
python使用Apriori算法进行关联性解析
Dec 21 #Python
python实现kMeans算法
Dec 21 #Python
利用Tkinter(python3.6)实现一个简单计算器
Dec 21 #Python
python编写朴素贝叶斯用于文本分类
Dec 21 #Python
python并发2之使用asyncio处理并发
Dec 21 #Python
You might like
利用PHP创建动态图像
2006/10/09 PHP
初识php MVC
2014/09/10 PHP
php中chdir()函数用法实例
2014/11/13 PHP
ThinkPHP控制器里javascript代码不能执行的解决方法
2014/11/22 PHP
CentOS7.0下安装PHP5.6.30服务的教程详解
2018/09/29 PHP
Prototype使用指南之selector.js
2007/01/10 Javascript
仿新浪微博返回顶部的jquery实现代码
2012/10/01 Javascript
浅谈JavaScript中null和undefined
2015/07/09 Javascript
Jquery AJAX POST与GET之间的区别详细介绍
2016/10/17 Javascript
最常见的左侧分类菜单栏jQuery实现代码
2016/11/28 Javascript
JavaScript实现事件的中断传播和行为阻止方法示例
2017/01/20 Javascript
jQuery基于Ajax方式提交表单功能示例
2017/02/10 Javascript
JavaScript模块化之使用requireJS按需加载
2017/04/12 Javascript
jsonp跨域获取数据的基础教程
2018/07/01 Javascript
angular ng-model 无法获取值的处理方法
2018/10/02 Javascript
JS选取DOM元素常见操作方法实例分析
2018/12/10 Javascript
如何使用VuePress搭建一个类型element ui文档
2019/02/14 Javascript
Jquery实现获取子元素的方法分析
2019/08/24 jQuery
js实现无缝轮播图插件封装
2020/07/31 Javascript
解决ant Design中Select设置initialValue时的大坑
2020/10/29 Javascript
express异步函数异常捕获示例详解
2020/11/30 Javascript
[53:13]DOTA2-DPC中国联赛 正赛 DLG vs PHOENIX BO3 第三场 1月18日
2021/03/11 DOTA
Python生成随机数的方法
2014/01/14 Python
python运行其他程序的实现方法
2017/07/14 Python
Python semaphore evevt生产者消费者模型原理解析
2020/03/18 Python
Python的历史与优缺点整理
2020/05/26 Python
Pytorch之Tensor和Numpy之间的转换的实现方法
2020/09/03 Python
Python爬取酷狗MP3音频的步骤
2021/02/26 Python
野兽派官方旗舰店:THE BEAST 野兽派
2016/08/05 全球购物
英国健身专家:WIT Fitness
2021/02/09 全球购物
自我鉴定的范文
2013/10/03 职场文书
公务员职务工作的自我评价
2013/11/01 职场文书
社区中秋节活动方案
2014/01/29 职场文书
2014年群众路线党员自我评议
2014/09/24 职场文书
SQLServer 日期函数大全(小结)
2021/04/08 SQL Server
JMeter对MySQL数据库进行压力测试的实现步骤
2022/01/22 MySQL