python实现基于信息增益的决策树归纳


Posted in Python onDecember 18, 2018

本文实例为大家分享了基于信息增益的决策树归纳的Python实现代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
from copy import copy
 
#加载训练数据
#文件格式:属性标号,是否连续【yes|no】,属性说明
attribute_file_dest = 'F:\\bayes_categorize\\attribute.dat'
attribute_file = open(attribute_file_dest)
 
#文件格式:rec_id,attr1_value,attr2_value,...,attrn_value,class_id
trainning_data_file_dest = 'F:\\bayes_categorize\\trainning_data.dat'
trainning_data_file = open(trainning_data_file_dest)
 
#文件格式:class_id,class_desc
class_desc_file_dest = 'F:\\bayes_categorize\\class_desc.dat'
class_desc_file = open(class_desc_file_dest)
 
 
root_attr_dict = {}
for line in attribute_file :
  line = line.strip()
  fld_list = line.split(',')
  root_attr_dict[int(fld_list[0])] = tuple(fld_list[1:])
 
class_dict = {}
for line in class_desc_file :
  line = line.strip()
  fld_list = line.split(',')
  class_dict[int(fld_list[0])] = fld_list[1]
  
trainning_data_dict = {}
class_member_set_dict = {}
for line in trainning_data_file :
  line = line.strip()
  fld_list = line.split(',')
  rec_id = int(fld_list[0])
  a1 = int(fld_list[1])
  a2 = int(fld_list[2])
  a3 = float(fld_list[3])
  c_id = int(fld_list[4])
  
  if c_id not in class_member_set_dict :
    class_member_set_dict[c_id] = set()
  class_member_set_dict[c_id].add(rec_id)
  trainning_data_dict[rec_id] = (a1 , a2 , a3 , c_id)
  
attribute_file.close()
class_desc_file.close()
trainning_data_file.close()
 
class_possibility_dict = {}
for c_id in class_member_set_dict :
  class_possibility_dict[c_id] = (len(class_member_set_dict[c_id]) + 0.0)/len(trainning_data_dict)  
 
#等待分类的数据
data_to_classify_file_dest = 'F:\\bayes_categorize\\trainning_data_new.dat'
data_to_classify_file = open(data_to_classify_file_dest)
data_to_classify_dict = {}
for line in data_to_classify_file :
  line = line.strip()
  fld_list = line.split(',')
  rec_id = int(fld_list[0])
  a1 = int(fld_list[1])
  a2 = int(fld_list[2])
  a3 = float(fld_list[3])
  c_id = int(fld_list[4])
  data_to_classify_dict[rec_id] = (a1 , a2 , a3 , c_id)
data_to_classify_file.close()
 
 
 
 
'''
决策树的表达
结点的需求:
1、指示出是哪一种分区 一共3种 一是离散穷举 二是连续有分裂点 三是离散有判别集合 零是叶子结点
2、保存分类所需信息
3、子结点列表
每个结点用Tuple类型表示
元素一是整形,取值123 分别对应两种分裂类型
元素二是集合类型 对于1保存所有的离散值 对于2保存分裂点 对于3保存判别集合 对于0保存分类结果类标号
元素三是dict key对于1来说是某个的离散值 对于23来说只有12两种 对于2来说1代表小于等于分裂点
对于3来说1代表属于判别集合
'''
 
  
#对于一个成员列表,计算其熵
#公式为 Info_D = - sum(pi * log2 (pi)) pi为一个元素属于Ci的概率,用|Ci|/|D|计算 ,对所有分类求和
def get_entropy( member_list ) :
  #成员总数
  mem_cnt = len(member_list)
  #首先找出member中所包含的分类
  class_dict = {}
  for mem_id in member_list :
    c_id = trainning_data_dict[mem_id][3]
    if c_id not in class_dict :
      class_dict[c_id] = set()
    class_dict[c_id].add(mem_id)
  
  tmp_sum = 0.0
  for c_id in class_dict :
    pi = ( len(class_dict[c_id]) + 0.0 ) / mem_cnt
    tmp_sum += pi * mlab.log2(pi)
  tmp_sum = -tmp_sum
  return tmp_sum
    
 
def attribute_selection_method( member_list , attribute_dict ) :
  #先计算原始的熵
  info_D = get_entropy(member_list)
  
  max_info_Gain = 0.0
  attr_get = 0
  split_point = 0.0
  for attr_id in attribute_dict :
    #对于每一个属性计算划分后的熵
    #信息增益等于原始的熵减去划分后的熵
    info_D_new = 0
    #如果是连续属性
    if attribute_dict[attr_id][0] == 'yes' :
      #先得到memberlist中此属性的取值序列,把序列中每一对相邻项的中值作为划分点计算熵
      #找出其中最小的,作为此连续属性的划分点
      value_list = []
      for mem_id in member_list :
        value_list.append(trainning_data_dict[mem_id][attr_id - 1])
      
      #获取相邻元素的中值序列
      mid_value_list = []
      value_list.sort()
      #print value_list
      last_value = None
      for value in value_list :
        if value == last_value :
          continue
        if last_value is not None :
          mid_value_list.append((last_value+value)/2)
        last_value = value
      #print mid_value_list
      #对于中值序列做循环
      #计算以此值做为划分点的熵
      #总的熵等于两个划分的熵乘以两个划分的比重
      min_info = 1000000000.0
      total_mens = len(member_list) + 0.0
      for mid_value in mid_value_list :
        #小于mid_value的mem
        less_list = []
        #大于
        more_list = []
        for tmp_mem_id in member_list :
          if trainning_data_dict[tmp_mem_id][attr_id - 1] <= mid_value :
            less_list.append(tmp_mem_id)
          else :
            more_list.append(tmp_mem_id)
        sum_info = len(less_list)/total_mens * get_entropy(less_list) \
        + len(more_list)/total_mens * get_entropy(more_list)
        
        if sum_info < min_info :
          min_info = sum_info
          split_point = mid_value
          
      info_D_new = min_info
    #如果是离散属性
    else :
      #计算划分后的熵
      #采用循环累加的方式
      attr_value_member_dict = {} #键为attribute value , 值为memberlist
      for tmp_mem_id in member_list :
        attr_value = trainning_data_dict[tmp_mem_id][attr_id - 1]
        if attr_value not in attr_value_member_dict :
          attr_value_member_dict[attr_value] = []
        attr_value_member_dict[attr_value].append(tmp_mem_id)
      #将每个离散值的熵乘以比重加到这上面
      total_mens = len(member_list) + 0.0
      sum_info = 0.0
      for a_value in attr_value_member_dict :
        sum_info += len(attr_value_member_dict[a_value])/total_mens \
        * get_entropy(attr_value_member_dict[a_value])
      
      info_D_new = sum_info
    
    info_Gain = info_D - info_D_new
    if info_Gain > max_info_Gain :
      max_info_Gain = info_Gain
      attr_get = attr_id
  
  #如果是离散的
  #print 'attr_get ' + str(attr_get)
  if attribute_dict[attr_get][0] == 'no' :
    return (1 , attr_get , split_point)
  else :  
    return (2 , attr_get , split_point)
  #第三类先不考虑
 
def get_decision_tree(father_node , key , member_list , attr_dict ) :
  #最终的结果是新建一个结点,并且添加到father_node的sub_node_dict,对key为键
  #检查memberlist 如果都是同类的,则生成一个叶子结点,set里面保存类标号
  class_set = set()
  for mem_id in member_list :
    class_set.add(trainning_data_dict[mem_id][3])
  if len(class_set) == 1 :
    father_node[2][key] = (0 , (1 , class_set) , {} )
    return
  
  #检查attribute_list,如果为空,产生叶子结点,类标号为memberlist中多数元素的类标号
  #如果几个类的成员等量,则打印提示,并且全部添加到set里面
  if not attr_dict :
    class_cnt_dict = {}
    for mem_id in member_list :
      c_id = trainning_data_dict[mem_id][3]
      if c_id not in class_cnt_dict :
        class_cnt_dict[c_id] = 1
      else :
        class_cnt_dict[c_id] += 1
        
    class_set = set()
    max_cnt = 0
    for c_id in class_cnt_dict :
      if class_cnt_dict[c_id] > max_cnt :
        max_cnt = class_cnt_dict[c_id]
        class_set.clear()
        class_set.add(c_id)
      elif class_cnt_dict[c_id] == max_cnt :
        class_set.add(c_id)
    
    if len(class_set) > 1 :
      print 'more than one class !'
    
    father_node[2][key] = (0 , (1 , class_set ) , {} )
    return
  
  #找出最好的分区方案 , 暂不考虑第三种划分方法
  #比较所有离散属性和所有连续属性的所有中值点划分的信息增益
  split_criterion = attribute_selection_method(member_list , attr_dict)
  #print split_criterion
  selected_plan_id = split_criterion[0]
  selected_attr_id = split_criterion[1]
  
  #如果采用的是离散属性做为分区方案,删除这个属性
  new_attr_dict = copy(attr_dict)
  if attr_dict[selected_attr_id][0] == 'no' :
    del new_attr_dict[selected_attr_id]
  
  #建立一个结点new_node,father_node[2][key] = new_node
  #然后对new node的每一个key , sub_member_list,
  #调用 get_decision_tree(new_node , new_key , sub_member_list , new_attribute_dict)
  #实现递归
  ele2 = ( selected_attr_id , set() )
  #如果是1 , ele2保存所有离散值
  if selected_plan_id == 1 :
    for mem_id in member_list :
      ele2[1].add(trainning_data_dict[mem_id][selected_attr_id - 1])
  #如果是2,ele2保存分裂点
  elif selected_plan_id == 2 :
    ele2[1].add(split_criterion[2])
  #如果是3则保存判别集合,先不管
  else :
    print 'not completed'
    pass
    
  new_node = ( selected_plan_id , ele2 , {} )
  father_node[2][key] = new_node
  
  #生成KEY,并递归调用
  if selected_plan_id == 1 :
    #每个attr_value是一个key
    attr_value_member_dict = {}
    for mem_id in member_list :
      attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
      if attr_value not in attr_value_member_dict :
        attr_value_member_dict[attr_value] = []
      attr_value_member_dict[attr_value].append(mem_id)
    for attr_value in attr_value_member_dict :
      get_decision_tree(new_node , attr_value , attr_value_member_dict[attr_value] , new_attr_dict)
    pass
  elif selected_plan_id == 2 :
    #key 只有12 , 小于等于分裂点的是1 , 大于的是2
    less_list = []
    more_list = []
    for mem_id in member_list :
      attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
      if attr_value <= split_criterion[2] :
        less_list.append(mem_id)
      else :
        more_list.append(mem_id)
    #if len(less_list) != 0 :
    get_decision_tree(new_node , 1 , less_list , new_attr_dict)
    #if len(more_list) != 0 :
    get_decision_tree(new_node , 2 , more_list , new_attr_dict)
    pass
  #如果是3则保存判别集合,先不管
  else :
    print 'not completed'
    pass
  
def get_class_sub(node , tp ) :
  #
  attr_id = node[1][0]
  plan_id = node[0]
  key = 0
  if plan_id == 0 :
    return node[1][1]
  elif plan_id == 1 :
    key = tp[attr_id - 1]
  elif plan_id == 2 :
    split_point = tuple(node[1][1])[0]
    attr_value = tp[attr_id - 1]
    if attr_value <= split_point :
      key = 1
    else :
      key = 2
  else :
    print 'error'
    return set()
    
  return get_class_sub(node[2][key] , tp )
 
def get_class(r_node , tp) :
  #tp为一组属性值
  if r_node[0] != -1 :
    print 'error'
    return set()
  
  if 1 in r_node[2] :
    return get_class_sub(r_node[2][1] , tp)
  else :
    print 'error'
    return set()
  
  
if __name__ == '__main__' :
  root_node = ( -1 , set() , {} )
  mem_list = trainning_data_dict.keys()
  get_decision_tree(root_node , 1 , mem_list , root_attr_dict )
 
  #测试分类器的准确率
  diff_cnt = 0
  for mem_id in data_to_classify_dict :
    c_id = get_class(root_node , data_to_classify_dict[mem_id][0:3])
    if tuple(c_id)[0] != data_to_classify_dict[mem_id][3] :
      print tuple(c_id)[0]
      print data_to_classify_dict[mem_id][3]
      print 'different'
      diff_cnt += 1
  print diff_cnt

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
web.py获取上传文件名的正确方法
Aug 26 Python
python 调用HBase的简单实例
Dec 18 Python
速记Python布尔值
Nov 09 Python
python实现将读入的多维list转为一维list的方法
Jun 28 Python
python中struct模块之字节型数据的处理方法
Aug 27 Python
Django实现简单网页弹出警告代码
Nov 15 Python
Python反爬虫伪装浏览器进行爬虫
Feb 28 Python
Django使用list对单个或者多个字段求values值实例
Mar 31 Python
使用python修改文件并立即写回到原始位置操作(inplace读写)
Jun 28 Python
对Python 字典元素进行删除的方法
Jul 31 Python
python跨文件使用全局变量的实现
Nov 17 Python
Python 实现进度条的六种方式
Jan 06 Python
Django实现一对多表模型的跨表查询方法
Dec 18 #Python
Python实现字典排序、按照list中字典的某个key排序的方法示例
Dec 18 #Python
python实现求特征选择的信息增益
Dec 18 #Python
python实现连续图文识别
Dec 18 #Python
Django ManyToManyField 跨越中间表查询的方法
Dec 18 #Python
Python列表list排列组合操作示例
Dec 18 #Python
python实现二维插值的三维显示
Dec 17 #Python
You might like
桌面中心(一)创建数据库
2006/10/09 PHP
PHP EOT定界符的使用详解
2008/09/30 PHP
PHP 图片文件上传实现代码
2010/12/29 PHP
TP(thinkPHP)框架多层控制器和多级控制器的使用示例
2018/06/13 PHP
Yii2.0 RESTful API 基础配置教程详解
2018/12/26 PHP
js关闭父窗口时关闭子窗口
2013/04/01 Javascript
JavaScript使用setInterval()函数实现简单轮询操作的方法
2015/02/02 Javascript
javascript实现列表滚动的方法
2015/07/30 Javascript
js制作带有遮罩弹出层实现登录注册表单特效代码分享
2015/09/05 Javascript
基于jQuery的Web上传插件Uploadify使用示例
2016/05/19 Javascript
详解Vue整合axios的实例代码
2017/06/21 Javascript
JavaScript判断变量名是否存在数组中的实例
2017/12/28 Javascript
javascript高仿热血传奇游戏实现代码
2018/02/22 Javascript
jQuery实现百度图片移入移出内容提示框上下左右移动的效果
2018/06/05 jQuery
Node.js模拟发起http请求从异步转同步的5种用法
2018/09/26 Javascript
Vue form表单动态添加组件实战案例
2019/09/02 Javascript
Vue是怎么渲染template内的标签内容的
2020/06/05 Javascript
详解JavaScript 的执行机制
2020/09/18 Javascript
Python实现对比不同字体中的同一字符的显示效果
2015/04/23 Python
两个使用Python脚本操作文件的小示例分享
2015/08/27 Python
详解python函数传参是传值还是传引用
2018/01/16 Python
python3.4.3下逐行读入txt文本并去重的方法
2018/04/29 Python
使用Python监控文件内容变化代码实例
2018/06/04 Python
在Python 中实现图片加框和加字的方法
2019/01/26 Python
django框架model orM使用字典作为参数,保存数据的方法分析
2019/06/24 Python
Python的条件锁与事件共享详解
2019/09/12 Python
python 消费 kafka 数据教程
2019/12/21 Python
Jupyter notebook运行Spark+Scala教程
2020/04/10 Python
keras用auc做metrics以及早停实例
2020/07/02 Python
阿迪达斯俄罗斯官方商城:adidas俄罗斯
2017/03/08 全球购物
大学总结自我鉴定
2014/01/18 职场文书
电视节目策划方案
2014/05/16 职场文书
个人租房协议书样本
2014/10/01 职场文书
运动会广播稿50字-100字
2014/10/11 职场文书
新农村建设指导员工作总结
2015/08/13 职场文书
利用Apache Common将java对象池化的问题
2022/06/16 Servers