python实现决策树分类


Posted in Python onAugust 30, 2018

上一篇博客主要介绍了决策树的原理,这篇主要介绍他的实现,代码环境python 3.4,实现的是ID3算法,首先为了后面matplotlib的绘图方便,我把原来的中文数据集变成了英文。

原始数据集:

python实现决策树分类

变化后的数据集在程序代码中体现,这就不截图了

构建决策树的代码如下:

#coding :utf-8
'''
2017.6.25 author :Erin 
   function: "decesion tree" ID3
   
'''
import numpy as np
import pandas as pd
from math import log
import operator 
def load_data():
 
 #data=np.array(data)
 data=[['teenager' ,'high', 'no' ,'same', 'no'],
   ['teenager', 'high', 'no', 'good', 'no'],
   ['middle_aged' ,'high', 'no', 'same', 'yes'],
   ['old_aged', 'middle', 'no' ,'same', 'yes'],
   ['old_aged', 'low', 'yes', 'same' ,'yes'],
   ['old_aged', 'low', 'yes', 'good', 'no'],
   ['middle_aged', 'low' ,'yes' ,'good', 'yes'],
   ['teenager' ,'middle' ,'no', 'same', 'no'],
   ['teenager', 'low' ,'yes' ,'same', 'yes'],
   ['old_aged' ,'middle', 'yes', 'same', 'yes'],
   ['teenager' ,'middle', 'yes', 'good', 'yes'],
   ['middle_aged' ,'middle', 'no', 'good', 'yes'],
   ['middle_aged', 'high', 'yes', 'same', 'yes'],
   ['old_aged', 'middle', 'no' ,'good' ,'no']]
 features=['age','input','student','level']
 return data,features
 
def cal_entropy(dataSet):
 '''
 输入data ,表示带最后标签列的数据集
 计算给定数据集总的信息熵
 {'是': 9, '否': 5}
 0.9402859586706309
 '''
 
 numEntries = len(dataSet)
 labelCounts = {}
 for featVec in dataSet:
  label = featVec[-1]
  if label not in labelCounts.keys():
   labelCounts[label] = 0
  labelCounts[label] += 1
 entropy = 0.0
 for key in labelCounts.keys():
  p_i = float(labelCounts[key]/numEntries)
  entropy -= p_i * log(p_i,2)#log(x,10)表示以10 为底的对数
 return entropy
 
def split_data(data,feature_index,value):
 '''
 划分数据集
 feature_index:用于划分特征的列数,例如“年龄”
 value:划分后的属性值:例如“青少年”
 '''
 data_split=[]#划分后的数据集
 for feature in data:
  if feature[feature_index]==value:
   reFeature=feature[:feature_index]
   reFeature.extend(feature[feature_index+1:])
   data_split.append(reFeature)
 return data_split
def choose_best_to_split(data):
 
 '''
 根据每个特征的信息增益,选择最大的划分数据集的索引特征
 '''
 
 count_feature=len(data[0])-1#特征个数4
 #print(count_feature)#4
 entropy=cal_entropy(data)#原数据总的信息熵
 #print(entropy)#0.9402859586706309
 
 max_info_gain=0.0#信息增益最大
 split_fea_index = -1#信息增益最大,对应的索引号
 
 for i in range(count_feature):
  
  feature_list=[fe_index[i] for fe_index in data]#获取该列所有特征值
  #######################################
  '''
  print('feature_list')
  ['青少年', '青少年', '中年', '老年', '老年', '老年', '中年', '青少年', '青少年', '老年',
  '青少年', '中年', '中年', '老年']
  0.3467680694480959 #对应上篇博客中的公式 =(1)*5/14
  0.3467680694480959
  0.6935361388961918
  '''
  # print(feature_list)
  unqval=set(feature_list)#去除重复
  Pro_entropy=0.0#特征的熵
  for value in unqval:#遍历改特征下的所有属性
   sub_data=split_data(data,i,value)
   pro=len(sub_data)/float(len(data))
   Pro_entropy+=pro*cal_entropy(sub_data)
   #print(Pro_entropy)
   
  info_gain=entropy-Pro_entropy
  if(info_gain>max_info_gain):
   max_info_gain=info_gain
   split_fea_index=i
 return split_fea_index
  
  
##################################################
def most_occur_label(labels):
 #sorted_label_count[0][0] 次数最多的类标签
 label_count={}
 for label in labels:
  if label not in label_count.keys():
   label_count[label]=0
  else:
   label_count[label]+=1
  sorted_label_count = sorted(label_count.items(),key = operator.itemgetter(1),reverse = True)
 return sorted_label_count[0][0]
def build_decesion_tree(dataSet,featnames):
 '''
 字典的键存放节点信息,分支及叶子节点存放值
 '''
 featname = featnames[:]    ################
 classlist = [featvec[-1] for featvec in dataSet] #此节点的分类情况
 if classlist.count(classlist[0]) == len(classlist): #全部属于一类
  return classlist[0]
 if len(dataSet[0]) == 1:   #分完了,没有属性了
  return Vote(classlist)  #少数服从多数
 # 选择一个最优特征进行划分
 bestFeat = choose_best_to_split(dataSet)
 bestFeatname = featname[bestFeat]
 del(featname[bestFeat])  #防止下标不准
 DecisionTree = {bestFeatname:{}}
 # 创建分支,先找出所有属性值,即分支数
 allvalue = [vec[bestFeat] for vec in dataSet]
 specvalue = sorted(list(set(allvalue))) #使有一定顺序
 for v in specvalue:
  copyfeatname = featname[:]
  DecisionTree[bestFeatname][v] = build_decesion_tree(split_data(dataSet,bestFeat,v),copyfeatname)
 return DecisionTree

绘制可视化图的代码如下:

def getNumLeafs(myTree):
 '计算决策树的叶子数'
 
 # 叶子数
 numLeafs = 0
 # 节点信息
 sides = list(myTree.keys()) 
 firstStr =sides[0]
 # 分支信息
 secondDict = myTree[firstStr]
 
 for key in secondDict.keys(): # 遍历所有分支
  # 子树分支则递归计算
  if type(secondDict[key]).__name__=='dict':
   numLeafs += getNumLeafs(secondDict[key])
  # 叶子分支则叶子数+1
  else: numLeafs +=1
  
 return numLeafs
 
 
def getTreeDepth(myTree):
 '计算决策树的深度'
 
 # 最大深度
 maxDepth = 0
 # 节点信息
 sides = list(myTree.keys()) 
 firstStr =sides[0]
 # 分支信息
 secondDict = myTree[firstStr]
 
 for key in secondDict.keys(): # 遍历所有分支
  # 子树分支则递归计算
  if type(secondDict[key]).__name__=='dict':
   thisDepth = 1 + getTreeDepth(secondDict[key])
  # 叶子分支则叶子数+1
  else: thisDepth = 1
  
  # 更新最大深度
  if thisDepth > maxDepth: maxDepth = thisDepth
  
 return maxDepth
 
import matplotlib.pyplot as plt
 
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
 
# ==================================================
# 输入:
#  nodeTxt:  终端节点显示内容
#  centerPt: 终端节点坐标
#  parentPt: 起始节点坐标
#  nodeType: 终端节点样式
# 输出:
#  在图形界面中显示输入参数指定样式的线段(终端带节点)
# ==================================================
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
 '画线(末端带一个点)'
  
 createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
 
# =================================================================
# 输入:
#  cntrPt:  终端节点坐标
#  parentPt: 起始节点坐标
#  txtString: 待显示文本内容
# 输出:
#  在图形界面指定位置(cntrPt和parentPt中间)显示文本内容(txtString)
# =================================================================
def plotMidText(cntrPt, parentPt, txtString):
 '在指定位置添加文本'
 
 # 中间位置坐标
 xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
 yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
 
 createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
 
# ===================================
# 输入:
#  myTree: 决策树
#  parentPt: 根节点坐标
#  nodeTxt: 根节点坐标信息
# 输出:
#  在图形界面绘制决策树
# ===================================
def plotTree(myTree, parentPt, nodeTxt):
 '绘制决策树'
 
 # 当前树的叶子数
 numLeafs = getNumLeafs(myTree)
 # 当前树的节点信息
 sides = list(myTree.keys()) 
 firstStr =sides[0]
 
 # 定位第一棵子树的位置(这是蛋疼的一部分)
 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
 
 # 绘制当前节点到子树节点(含子树节点)的信息
 plotMidText(cntrPt, parentPt, nodeTxt)
 plotNode(firstStr, cntrPt, parentPt, decisionNode)
 
 # 获取子树信息
 secondDict = myTree[firstStr]
 # 开始绘制子树,纵坐标-1。  
 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
  
 for key in secondDict.keys(): # 遍历所有分支
  # 子树分支则递归
  if type(secondDict[key]).__name__=='dict':
   plotTree(secondDict[key],cntrPt,str(key))
  # 叶子分支则直接绘制
  else:
   plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
   plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
   plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
  
 # 子树绘制完毕,纵坐标+1。
 plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
 
# ==============================
# 输入:
#  myTree: 决策树
# 输出:
#  在图形界面显示决策树
# ==============================
def createPlot(inTree):
 '显示决策树'
 
 # 创建新的图像并清空 - 无横纵坐标
 fig = plt.figure(1, facecolor='white')
 fig.clf()
 axprops = dict(xticks=[], yticks=[])
 createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
 
 # 树的总宽度 高度
 plotTree.totalW = float(getNumLeafs(inTree))
 plotTree.totalD = float(getTreeDepth(inTree))
 
 # 当前绘制节点的坐标
 plotTree.xOff = -0.5/plotTree.totalW; 
 plotTree.yOff = 1.0;
 
 # 绘制决策树
 plotTree(inTree, (0.5,1.0), '')
 
 plt.show()
 
if __name__ == '__main__':
 data,features=load_data()
 split_fea_index=choose_best_to_split(data)
 newtree=build_decesion_tree(data,features)
 print(newtree)
 createPlot(newtree)
 '''
 {'age': {'old_aged': {'level': {'same': 'yes', 'good': 'no'}}, 'teenager': {'student': {'no': 'no', 'yes': 'yes'}}, 'middle_aged': 'yes'}}
 '''

结果如下:

python实现决策树分类

怎么用决策树分类,将会在下一章。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python语言编写电脑时间自动同步小工具
Mar 08 Python
简单的编程0基础下Python入门指引
Apr 01 Python
20招让你的Python飞起来!
Sep 27 Python
python制作爬虫爬取京东商品评论教程
Dec 16 Python
Python多进程与服务器并发原理及用法实例分析
Aug 21 Python
python使用opencv对图像mask处理的方法
Jul 05 Python
Python测试模块doctest使用解析
Aug 10 Python
使用matlab或python将txt文件转为excel表格
Nov 01 Python
python实现将一维列表转换为多维列表(numpy+reshape)
Nov 29 Python
Python中的X[:,0]、X[:,1]、X[:,:,0]、X[:,:,1]、X[:,m:n]和X[:,:,m:n]
Feb 13 Python
社区版pycharm创建django项目的方法(pycharm的newproject左侧没有项目选项)
Sep 23 Python
Python 虚拟环境工作原理解析
Dec 24 Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
python实现录音小程序
Oct 26 #Python
Python图像处理之简单画板实现方法示例
Aug 30 #Python
You might like
推荐一篇入门级的Class文章
2007/03/19 PHP
php中explode的负数limit用法分析
2015/02/27 PHP
PHP实现执行外部程序的方法详解
2017/08/17 PHP
PHP基于双向链表与排序操作实现的会员排名功能示例
2017/12/26 PHP
安装docker和docker-compose实例详解
2019/07/30 PHP
JS判断是否为数字,是否为整数,是否为浮点数的代码
2010/04/24 Javascript
jQuery JSON的解析方式分享
2011/04/05 Javascript
根据经纬度计算地球上两点之间的距离js实现代码
2013/03/05 Javascript
九种js弹出对话框的方法总结
2013/03/12 Javascript
Javascript之this关键字深入解析
2013/11/12 Javascript
jQuery照片伸缩效果不影响其他元素的布局
2014/05/09 Javascript
javascript引擎长时间独占线程造成卡顿的解决方案
2014/12/03 Javascript
使用jQuery获取data-的自定义属性
2015/11/10 Javascript
jquery.validate提示错误信息位置方法
2016/01/22 Javascript
Bootstrap Navbar Component实现响应式导航
2016/10/08 Javascript
使用微信小程序开发前端【快速入门】
2016/12/05 Javascript
JS变量中有var定义和无var定义的区别以及es6中let命令和const命令
2017/02/19 Javascript
js使用highlight.js高亮你的代码
2017/08/18 Javascript
jQuery UI Draggable + Sortable 结合使用(实例讲解)
2017/09/07 jQuery
关于Google发布的JavaScript代码规范你要知道哪些
2018/04/04 Javascript
Node 搭建一个静态资源服务器的实现
2019/05/20 Javascript
layui 根据后台数据动态创建下拉框并同时默认选中的实例
2019/09/02 Javascript
Python导入txt数据到mysql的方法
2015/04/08 Python
Python验证码识别处理实例
2015/12/28 Python
python实现二叉查找树实例代码
2018/02/08 Python
python清除字符串中间空格的实例讲解
2018/05/11 Python
HTML5等待加载动画效果
2017/07/27 HTML / CSS
马耳他航空公司官方网站:Air Malta
2019/05/15 全球购物
英国领先的酒杯和水晶玻璃器皿制造商:Dartington Crystal
2019/06/23 全球购物
高二地理教学反思
2014/01/24 职场文书
高中毕业生登记表自我鉴定范文
2014/03/18 职场文书
2014年工程师工作总结
2014/11/25 职场文书
小升初自荐信范文
2015/03/05 职场文书
2015年语文教学工作总结
2015/05/25 职场文书
python四个坐标点对图片区域最小外接矩形进行裁剪
2021/06/04 Python
SQL实现LeetCode(176.第二高薪水)
2021/08/04 MySQL