python实现基于信息增益的决策树归纳


Posted in Python onDecember 18, 2018

本文实例为大家分享了基于信息增益的决策树归纳的Python实现代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
from copy import copy
 
#加载训练数据
#文件格式:属性标号,是否连续【yes|no】,属性说明
attribute_file_dest = 'F:\\bayes_categorize\\attribute.dat'
attribute_file = open(attribute_file_dest)
 
#文件格式:rec_id,attr1_value,attr2_value,...,attrn_value,class_id
trainning_data_file_dest = 'F:\\bayes_categorize\\trainning_data.dat'
trainning_data_file = open(trainning_data_file_dest)
 
#文件格式:class_id,class_desc
class_desc_file_dest = 'F:\\bayes_categorize\\class_desc.dat'
class_desc_file = open(class_desc_file_dest)
 
 
root_attr_dict = {}
for line in attribute_file :
  line = line.strip()
  fld_list = line.split(',')
  root_attr_dict[int(fld_list[0])] = tuple(fld_list[1:])
 
class_dict = {}
for line in class_desc_file :
  line = line.strip()
  fld_list = line.split(',')
  class_dict[int(fld_list[0])] = fld_list[1]
  
trainning_data_dict = {}
class_member_set_dict = {}
for line in trainning_data_file :
  line = line.strip()
  fld_list = line.split(',')
  rec_id = int(fld_list[0])
  a1 = int(fld_list[1])
  a2 = int(fld_list[2])
  a3 = float(fld_list[3])
  c_id = int(fld_list[4])
  
  if c_id not in class_member_set_dict :
    class_member_set_dict[c_id] = set()
  class_member_set_dict[c_id].add(rec_id)
  trainning_data_dict[rec_id] = (a1 , a2 , a3 , c_id)
  
attribute_file.close()
class_desc_file.close()
trainning_data_file.close()
 
class_possibility_dict = {}
for c_id in class_member_set_dict :
  class_possibility_dict[c_id] = (len(class_member_set_dict[c_id]) + 0.0)/len(trainning_data_dict)  
 
#等待分类的数据
data_to_classify_file_dest = 'F:\\bayes_categorize\\trainning_data_new.dat'
data_to_classify_file = open(data_to_classify_file_dest)
data_to_classify_dict = {}
for line in data_to_classify_file :
  line = line.strip()
  fld_list = line.split(',')
  rec_id = int(fld_list[0])
  a1 = int(fld_list[1])
  a2 = int(fld_list[2])
  a3 = float(fld_list[3])
  c_id = int(fld_list[4])
  data_to_classify_dict[rec_id] = (a1 , a2 , a3 , c_id)
data_to_classify_file.close()
 
 
 
 
'''
决策树的表达
结点的需求:
1、指示出是哪一种分区 一共3种 一是离散穷举 二是连续有分裂点 三是离散有判别集合 零是叶子结点
2、保存分类所需信息
3、子结点列表
每个结点用Tuple类型表示
元素一是整形,取值123 分别对应两种分裂类型
元素二是集合类型 对于1保存所有的离散值 对于2保存分裂点 对于3保存判别集合 对于0保存分类结果类标号
元素三是dict key对于1来说是某个的离散值 对于23来说只有12两种 对于2来说1代表小于等于分裂点
对于3来说1代表属于判别集合
'''
 
  
#对于一个成员列表,计算其熵
#公式为 Info_D = - sum(pi * log2 (pi)) pi为一个元素属于Ci的概率,用|Ci|/|D|计算 ,对所有分类求和
def get_entropy( member_list ) :
  #成员总数
  mem_cnt = len(member_list)
  #首先找出member中所包含的分类
  class_dict = {}
  for mem_id in member_list :
    c_id = trainning_data_dict[mem_id][3]
    if c_id not in class_dict :
      class_dict[c_id] = set()
    class_dict[c_id].add(mem_id)
  
  tmp_sum = 0.0
  for c_id in class_dict :
    pi = ( len(class_dict[c_id]) + 0.0 ) / mem_cnt
    tmp_sum += pi * mlab.log2(pi)
  tmp_sum = -tmp_sum
  return tmp_sum
    
 
def attribute_selection_method( member_list , attribute_dict ) :
  #先计算原始的熵
  info_D = get_entropy(member_list)
  
  max_info_Gain = 0.0
  attr_get = 0
  split_point = 0.0
  for attr_id in attribute_dict :
    #对于每一个属性计算划分后的熵
    #信息增益等于原始的熵减去划分后的熵
    info_D_new = 0
    #如果是连续属性
    if attribute_dict[attr_id][0] == 'yes' :
      #先得到memberlist中此属性的取值序列,把序列中每一对相邻项的中值作为划分点计算熵
      #找出其中最小的,作为此连续属性的划分点
      value_list = []
      for mem_id in member_list :
        value_list.append(trainning_data_dict[mem_id][attr_id - 1])
      
      #获取相邻元素的中值序列
      mid_value_list = []
      value_list.sort()
      #print value_list
      last_value = None
      for value in value_list :
        if value == last_value :
          continue
        if last_value is not None :
          mid_value_list.append((last_value+value)/2)
        last_value = value
      #print mid_value_list
      #对于中值序列做循环
      #计算以此值做为划分点的熵
      #总的熵等于两个划分的熵乘以两个划分的比重
      min_info = 1000000000.0
      total_mens = len(member_list) + 0.0
      for mid_value in mid_value_list :
        #小于mid_value的mem
        less_list = []
        #大于
        more_list = []
        for tmp_mem_id in member_list :
          if trainning_data_dict[tmp_mem_id][attr_id - 1] <= mid_value :
            less_list.append(tmp_mem_id)
          else :
            more_list.append(tmp_mem_id)
        sum_info = len(less_list)/total_mens * get_entropy(less_list) \
        + len(more_list)/total_mens * get_entropy(more_list)
        
        if sum_info < min_info :
          min_info = sum_info
          split_point = mid_value
          
      info_D_new = min_info
    #如果是离散属性
    else :
      #计算划分后的熵
      #采用循环累加的方式
      attr_value_member_dict = {} #键为attribute value , 值为memberlist
      for tmp_mem_id in member_list :
        attr_value = trainning_data_dict[tmp_mem_id][attr_id - 1]
        if attr_value not in attr_value_member_dict :
          attr_value_member_dict[attr_value] = []
        attr_value_member_dict[attr_value].append(tmp_mem_id)
      #将每个离散值的熵乘以比重加到这上面
      total_mens = len(member_list) + 0.0
      sum_info = 0.0
      for a_value in attr_value_member_dict :
        sum_info += len(attr_value_member_dict[a_value])/total_mens \
        * get_entropy(attr_value_member_dict[a_value])
      
      info_D_new = sum_info
    
    info_Gain = info_D - info_D_new
    if info_Gain > max_info_Gain :
      max_info_Gain = info_Gain
      attr_get = attr_id
  
  #如果是离散的
  #print 'attr_get ' + str(attr_get)
  if attribute_dict[attr_get][0] == 'no' :
    return (1 , attr_get , split_point)
  else :  
    return (2 , attr_get , split_point)
  #第三类先不考虑
 
def get_decision_tree(father_node , key , member_list , attr_dict ) :
  #最终的结果是新建一个结点,并且添加到father_node的sub_node_dict,对key为键
  #检查memberlist 如果都是同类的,则生成一个叶子结点,set里面保存类标号
  class_set = set()
  for mem_id in member_list :
    class_set.add(trainning_data_dict[mem_id][3])
  if len(class_set) == 1 :
    father_node[2][key] = (0 , (1 , class_set) , {} )
    return
  
  #检查attribute_list,如果为空,产生叶子结点,类标号为memberlist中多数元素的类标号
  #如果几个类的成员等量,则打印提示,并且全部添加到set里面
  if not attr_dict :
    class_cnt_dict = {}
    for mem_id in member_list :
      c_id = trainning_data_dict[mem_id][3]
      if c_id not in class_cnt_dict :
        class_cnt_dict[c_id] = 1
      else :
        class_cnt_dict[c_id] += 1
        
    class_set = set()
    max_cnt = 0
    for c_id in class_cnt_dict :
      if class_cnt_dict[c_id] > max_cnt :
        max_cnt = class_cnt_dict[c_id]
        class_set.clear()
        class_set.add(c_id)
      elif class_cnt_dict[c_id] == max_cnt :
        class_set.add(c_id)
    
    if len(class_set) > 1 :
      print 'more than one class !'
    
    father_node[2][key] = (0 , (1 , class_set ) , {} )
    return
  
  #找出最好的分区方案 , 暂不考虑第三种划分方法
  #比较所有离散属性和所有连续属性的所有中值点划分的信息增益
  split_criterion = attribute_selection_method(member_list , attr_dict)
  #print split_criterion
  selected_plan_id = split_criterion[0]
  selected_attr_id = split_criterion[1]
  
  #如果采用的是离散属性做为分区方案,删除这个属性
  new_attr_dict = copy(attr_dict)
  if attr_dict[selected_attr_id][0] == 'no' :
    del new_attr_dict[selected_attr_id]
  
  #建立一个结点new_node,father_node[2][key] = new_node
  #然后对new node的每一个key , sub_member_list,
  #调用 get_decision_tree(new_node , new_key , sub_member_list , new_attribute_dict)
  #实现递归
  ele2 = ( selected_attr_id , set() )
  #如果是1 , ele2保存所有离散值
  if selected_plan_id == 1 :
    for mem_id in member_list :
      ele2[1].add(trainning_data_dict[mem_id][selected_attr_id - 1])
  #如果是2,ele2保存分裂点
  elif selected_plan_id == 2 :
    ele2[1].add(split_criterion[2])
  #如果是3则保存判别集合,先不管
  else :
    print 'not completed'
    pass
    
  new_node = ( selected_plan_id , ele2 , {} )
  father_node[2][key] = new_node
  
  #生成KEY,并递归调用
  if selected_plan_id == 1 :
    #每个attr_value是一个key
    attr_value_member_dict = {}
    for mem_id in member_list :
      attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
      if attr_value not in attr_value_member_dict :
        attr_value_member_dict[attr_value] = []
      attr_value_member_dict[attr_value].append(mem_id)
    for attr_value in attr_value_member_dict :
      get_decision_tree(new_node , attr_value , attr_value_member_dict[attr_value] , new_attr_dict)
    pass
  elif selected_plan_id == 2 :
    #key 只有12 , 小于等于分裂点的是1 , 大于的是2
    less_list = []
    more_list = []
    for mem_id in member_list :
      attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
      if attr_value <= split_criterion[2] :
        less_list.append(mem_id)
      else :
        more_list.append(mem_id)
    #if len(less_list) != 0 :
    get_decision_tree(new_node , 1 , less_list , new_attr_dict)
    #if len(more_list) != 0 :
    get_decision_tree(new_node , 2 , more_list , new_attr_dict)
    pass
  #如果是3则保存判别集合,先不管
  else :
    print 'not completed'
    pass
  
def get_class_sub(node , tp ) :
  #
  attr_id = node[1][0]
  plan_id = node[0]
  key = 0
  if plan_id == 0 :
    return node[1][1]
  elif plan_id == 1 :
    key = tp[attr_id - 1]
  elif plan_id == 2 :
    split_point = tuple(node[1][1])[0]
    attr_value = tp[attr_id - 1]
    if attr_value <= split_point :
      key = 1
    else :
      key = 2
  else :
    print 'error'
    return set()
    
  return get_class_sub(node[2][key] , tp )
 
def get_class(r_node , tp) :
  #tp为一组属性值
  if r_node[0] != -1 :
    print 'error'
    return set()
  
  if 1 in r_node[2] :
    return get_class_sub(r_node[2][1] , tp)
  else :
    print 'error'
    return set()
  
  
if __name__ == '__main__' :
  root_node = ( -1 , set() , {} )
  mem_list = trainning_data_dict.keys()
  get_decision_tree(root_node , 1 , mem_list , root_attr_dict )
 
  #测试分类器的准确率
  diff_cnt = 0
  for mem_id in data_to_classify_dict :
    c_id = get_class(root_node , data_to_classify_dict[mem_id][0:3])
    if tuple(c_id)[0] != data_to_classify_dict[mem_id][3] :
      print tuple(c_id)[0]
      print data_to_classify_dict[mem_id][3]
      print 'different'
      diff_cnt += 1
  print diff_cnt

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python ORM框架SQLAlchemy学习笔记之映射类使用实例和Session会话介绍
Jun 10 Python
python类装饰器用法实例
Jun 04 Python
Python读大数据txt
Mar 28 Python
Python实现单词翻译功能
Jun 06 Python
python深度优先搜索和广度优先搜索
Feb 07 Python
python GUI库图形界面开发之PyQt5日期时间控件QDateTimeEdit详细使用方法与实例
Feb 27 Python
python上传时包含boundary时的解决方法
Apr 08 Python
30行Python代码实现高分辨率图像导航的方法
May 22 Python
python学习将数据写入文件并保存方法
Jun 07 Python
python批量生成条形码的示例
Oct 10 Python
Python实现网络聊天室的示例代码(支持多人聊天与私聊)
Jan 27 Python
Python操作CSV格式文件的方法大全
Jul 15 Python
Django实现一对多表模型的跨表查询方法
Dec 18 #Python
Python实现字典排序、按照list中字典的某个key排序的方法示例
Dec 18 #Python
python实现求特征选择的信息增益
Dec 18 #Python
python实现连续图文识别
Dec 18 #Python
Django ManyToManyField 跨越中间表查询的方法
Dec 18 #Python
Python列表list排列组合操作示例
Dec 18 #Python
python实现二维插值的三维显示
Dec 17 #Python
You might like
php 应用程序安全防范技术研究
2009/09/25 PHP
纯php打造的tab选项卡效果代码(不用js)
2010/12/29 PHP
显示程序执行时间php函数代码
2013/08/29 PHP
php实现12306余票查询、价格查询示例
2014/04/17 PHP
php一个解析字符串排列数组的方法
2015/05/12 PHP
PHP单例模式定义与使用实例详解
2017/02/06 PHP
Yii2框架中日志的使用方法分析
2017/05/22 PHP
php求斐波那契数的两种实现方式【递归与递推】
2019/09/09 PHP
用Javascript 获取页面元素的位置的代码
2009/09/25 Javascript
基于jQuery的的一个隔行变色,鼠标移动变色的小插件
2010/07/06 Javascript
js中关于new Object时传参的一些细节分析
2011/03/13 Javascript
jquery cookie的用法总结
2013/11/18 Javascript
JavaScript中的small()方法使用详解
2015/06/08 Javascript
jquery实现Slide Out Navigation滑出式菜单效果代码
2015/09/07 Javascript
JS实现很实用的对联广告代码(可自适应高度)
2015/09/18 Javascript
30分钟快速掌握Bootstrap框架
2016/05/24 Javascript
Bootstrap Table使用整理(三)
2017/06/09 Javascript
微信小程序 上传头像的实例详解
2017/10/27 Javascript
Bootstrap 模态框多次显示后台提交多次BUG的解决方法
2017/12/26 Javascript
JS实现访问DOM对象指定节点的方法示例
2018/04/04 Javascript
vue.js图片转Base64上传图片并预览的实现方法
2018/08/02 Javascript
支付宝小程序自定义弹窗dialog插件的实现代码
2018/11/30 Javascript
angular学习之动态创建表单的方法
2018/12/07 Javascript
Python读写txt文本文件的操作方法全解析
2016/06/26 Python
Python 出现错误TypeError: ‘NoneType’ object is not iterable解决办法
2017/01/12 Python
Linux下多个Python版本安装教程
2018/08/15 Python
python运用pygame库实现双人弹球小游戏
2019/11/25 Python
python GUI库图形界面开发之PyQt5窗口类QMainWindow详细使用方法
2020/02/26 Python
css3发光搜索表单分享
2014/04/11 HTML / CSS
基于html5绘制圆形多角图案
2016/04/21 HTML / CSS
美国在线眼镜商城:Eyeglasses.com
2017/06/26 全球购物
澳大利亚领先的宠物用品商店:VetSupply
2017/09/08 全球购物
学校搬迁方案
2014/06/15 职场文书
律师授权委托书范本
2014/10/07 职场文书
销售经理工作检讨书
2015/02/19 职场文书
Go结合Gin导出Mysql数据到Excel表格
2022/08/05 Golang