Python实现的随机森林算法与简单总结


Posted in Python onJanuary 30, 2018

本文实例讲述了Python实现的随机森林算法。分享给大家供大家参考,具体如下:

随机森林是数据挖掘中非常常用的分类预测算法,以分类或回归的决策树为基分类器。算法的一些基本要点:

*对大小为m的数据集进行样本量同样为m的有放回抽样;
*对K个特征进行随机抽样,形成特征的子集,样本量的确定方法可以有平方根、自然对数等;
*每棵树完全生成,不进行剪枝;
*每个样本的预测结果由每棵树的预测投票生成(回归的时候,即各棵树的叶节点的平均)

著名的python机器学习包scikit learn的文档对此算法有比较详尽的介绍: http://scikit-learn.org/stable/modules/ensemble.html#random-forests

出于个人研究和测试的目的,基于经典的Kaggle 101泰坦尼克号乘客的数据集,建立模型并进行评估。比赛页面及相关数据集的下载:https://www.kaggle.com/c/titanic

泰坦尼克号的沉没,是历史上非常著名的海难。突然感到,自己面对的不再是冷冰冰的数据,而是用数据挖掘的方法,去研究具体的历史问题,也是饶有兴趣。言归正传,模型的主要的目标,是希望根据每个乘客的一系列特征,如性别、年龄、舱位、上船地点等,对其是否能生还进行预测,是非常典型的二分类预测问题。数据集的字段名及实例如下:

PassengerId Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
1 0 3 Braund, Mr. Owen Harris male 22 1 0 A/5 21171 7.25 S
2 1 1 Cumings, Mrs. John Bradley (Florence Briggs Thayer) female 38 1 0 PC 17599 71.2833 C85 C
3 1 3 Heikkinen, Miss. Laina female 26 0 0 STON/O2. 3101282 7.925 S
4 1 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35 1 0 113803 53.1 C123 S
5 0 3 Allen, Mr. William Henry male 35 0 0 373450 8.05 S

值得说明的是,SibSp是指sister brother spouse,即某个乘客随行的兄弟姐妹、丈夫、妻子的人数,Parch指parents,children

下面给出整个数据处理及建模过程,基于ubuntu+python 3.4( anaconda科学计算环境已经集成一系列常用包,pandas numpy sklearn等,这里强烈推荐)

懒得切换输入法,写的时候主要的注释都是英文,中文的注释是后来补充的:-)

# -*- coding: utf-8 -*-
"""
@author: kim
"""
from model import *#载入基分类器的代码
#ETL:same procedure to training set and test set
training=pd.read_csv('train.csv',index_col=0)
test=pd.read_csv('test.csv',index_col=0)
SexCode=pd.DataFrame([1,0],index=['female','male'],columns=['Sexcode']) #将性别转化为01
training=training.join(SexCode,how='left',on=training.Sex)
training=training.drop(['Name','Ticket','Embarked','Cabin','Sex'],axis=1)#删去几个不参与建模的变量,包括姓名、船票号,船舱号
test=test.join(SexCode,how='left',on=test.Sex)
test=test.drop(['Name','Ticket','Embarked','Cabin','Sex'],axis=1)
print('ETL IS DONE!')
#MODEL FITTING
#===============PARAMETER AJUSTMENT============
min_leaf=1
min_dec_gini=0.0001
n_trees=5
n_fea=int(math.sqrt(len(training.columns)-1))
#==============================================
'''''
BEST SCORE:0.83
min_leaf=30
min_dec_gini=0.001
n_trees=20
'''
#ESSEMBLE BY RANDOM FOREST
FOREST={}
tmp=list(training.columns)
tmp.pop(tmp.index('Survived'))
feaList=pd.Series(tmp)
for t in range(n_trees):
#  fea=[]
  feasample=feaList.sample(n=n_fea,replace=False)#select feature
  fea=feasample.tolist()
  fea.append('Survived')
#    feaNew=fea.append(target)
  subset=training.sample(n=len(training),replace=True)#generate the dataset with replacement
  subset=subset[fea]
#  print(str(t)+' Classifier built on feature:')
#  print(list(fea))
  FOREST[t]=tree_grow(subset,'Survived',min_leaf,min_dec_gini) #save the tree
#MODEL PREDICTION
#======================
currentdata=training
output='submission_rf_20151116_30_0.001_20'
#======================
prediction={}
for r in currentdata.index:#a row
  prediction_vote={1:0,0:0}
  row=currentdata.get(currentdata.index==r)
  for n in range(n_trees):
    tree_dict=FOREST[n] #a tree
    p=model_prediction(tree_dict,row)
    prediction_vote[p]+=1
  vote=pd.Series(prediction_vote)
  prediction[r]=list(vote.order(ascending=False).index)[0]#the vote result
result=pd.Series(prediction,name='Survived_p')
#del prediction_vote
#del prediction
#result.to_csv(output)
t=training.join(result,how='left')
accuracy=round(len(t[t['Survived']==t['Survived_p']])/len(t),5)
print(accuracy)

上述是随机森林的代码,如上所述,随机森林是一系列决策树的组合,决策树每次分裂,用Gini系数衡量当前节点的“不纯净度”,如果按照某个特征的某个分裂点对数据集划分后,能够让数据集的Gini下降最多(显著地减少了数据集输出变量的不纯度),则选为当前最佳的分割特征及分割点。代码如下:

# -*- coding: utf-8 -*-
"""
@author: kim
"""
import pandas as pd
import numpy as np
#import sklearn as sk
import math
def tree_grow(dataframe,target,min_leaf,min_dec_gini):
  tree={} #renew a tree
  is_not_leaf=(len(dataframe)>min_leaf)
  if is_not_leaf:
    fea,sp,gd=best_split_col(dataframe,target)
    if gd>min_dec_gini:
      tree['fea']=fea
      tree['val']=sp
#      dataframe.drop(fea,axis=1) #1116 modified
      l,r=dataSplit(dataframe,fea,sp)
      l.drop(fea,axis=1)
      r.drop(fea,axis=1)
      tree['left']=tree_grow(l,target,min_leaf,min_dec_gini)
      tree['right']=tree_grow(r,target,min_leaf,min_dec_gini)
    else:#return a leaf
      return leaf(dataframe[target])
  else:
    return leaf(dataframe[target])
  return tree
def leaf(class_lable):
  tmp={}
  for i in class_lable:
    if i in tmp:
      tmp[i]+=1
    else:
      tmp[i]=1
  s=pd.Series(tmp)
  s.sort(ascending=False)
  return s.index[0]
def gini_cal(class_lable):
  p_1=sum(class_lable)/len(class_lable)
  p_0=1-p_1
  gini=1-(pow(p_0,2)+pow(p_1,2))
  return gini
def dataSplit(dataframe,split_fea,split_val):
  left_node=dataframe[dataframe[split_fea]<=split_val]
  right_node=dataframe[dataframe[split_fea]>split_val]
  return left_node,right_node
def best_split_col(dataframe,target_name):
  best_fea=''#modified 1116
  best_split_point=0
  col_list=list(dataframe.columns)
  col_list.remove(target_name)
  gini_0=gini_cal(dataframe[target_name])
  n=len(dataframe)
  gini_dec=-99999999
  for col in col_list:
    node=dataframe[[col,target_name]]
    unique=node.groupby(col).count().index
    for split_point in unique: #unique value
      left_node,right_node=dataSplit(node,col,split_point)
      if len(left_node)>0 and len(right_node)>0:
        gini_col=gini_cal(left_node[target_name])*(len(left_node)/n)+gini_cal(right_node[target_name])*(len(right_node)/n)
        if (gini_0-gini_col)>gini_dec:
          gini_dec=gini_0-gini_col#decrease of impurity
          best_fea=col
          best_split_point=split_point
    #print(col,split_point,gini_0-gini_col)
  return best_fea,best_split_point,gini_dec
def model_prediction(model,row): #row is a df
  fea=model['fea']
  val=model['val']
  left=model['left']
  right=model['right']
  if row[fea].tolist()[0]<=val:#get the value
    branch=left
  else:
    branch=right
  if ('dict' in str( type(branch) )):
    prediction=model_prediction(branch,row)
  else:
    prediction=branch
  return prediction

实际上,上面的代码还有很大的效率提升的空间,数据集不是很大的情况下,如果选择一个较大的输入参数,例如生成100棵树,就会显著地变慢;同时,将预测结果提交至kaggle进行评测,发现在测试集上的正确率不是很高,比使用sklearn里面相应的包进行预测的正确率(0.77512)要稍低一点 :-(  如果要提升准确率,两个大方向: 构造新的特征;调整现有模型的参数。

这里是抛砖引玉,欢迎大家对我的建模思路和算法的实现方法提出修改意见。

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python实现冒泡,插入,选择排序简单实例
Aug 18 Python
python根据路径导入模块的方法
Sep 30 Python
使用Python下载歌词并嵌入歌曲文件中的实现代码
Nov 13 Python
用Python写冒泡排序代码
Apr 12 Python
简单学习Python time模块
Apr 29 Python
Django模板变量如何传递给外部js调用的方法小结
Jul 24 Python
python创建学生成绩管理系统
Nov 22 Python
Python基础之列表常见操作经典实例详解
Feb 26 Python
python -v 报错问题的解决方法
Sep 15 Python
Django数据库迁移常见使用方法
Nov 12 Python
python实现简单倒计时功能
Apr 21 Python
OpenCV全景图像拼接的实现示例
Jun 05 Python
Python决策树和随机森林算法实例详解
Jan 30 #Python
在Python 2.7即将停止支持时,我们为你带来了一份python 3.x迁移指南
Jan 30 #Python
python使用Tkinter实现在线音乐播放器
Jan 30 #Python
Python字典及字典基本操作方法详解
Jan 30 #Python
Python操作MySQL数据库的三种方法总结
Jan 30 #Python
python3.5 tkinter实现页面跳转
Jan 30 #Python
python 连接各类主流数据库的实例代码
Jan 30 #Python
You might like
可以在线执行PHP代码包装修正版
2008/03/15 PHP
PHP开发之归档格式phar文件概念与用法详解【创建,使用,解包还原提取】
2017/11/17 PHP
关于jQuery的inArray 方法介绍
2011/10/08 Javascript
滚动图片效果 jquery实现回旋滚动效果
2013/01/08 Javascript
Js,alert出现乱码问题的解决方法
2013/06/19 Javascript
window.open 以post方式传递参数示例代码
2014/02/27 Javascript
jquery获取当前点击对象的value方法
2014/02/28 Javascript
javascript模拟实现ajax加载框实例
2014/10/15 Javascript
jQuery前端框架easyui使用Dialog时bug处理
2014/12/05 Javascript
Bootstrap每天必学之基础排版
2015/11/20 Javascript
jQuery过滤选择器经典应用
2016/08/18 Javascript
jQuery使用deferreds串行多个ajax请求
2016/08/22 Javascript
jQuery web 组件 后台日历价格、库存设置的代码
2016/10/14 Javascript
微信小程序 标签传入数据
2017/05/08 Javascript
不得不看之JavaScript构造函数及new运算符
2017/08/21 Javascript
基于Datatables跳转到指定页的简单实例
2017/11/09 Javascript
js获取html页面代码中图片地址的实现代码
2018/03/05 Javascript
es6新特性之 class 基本用法解析
2018/05/05 Javascript
微信小程序仿朋友圈发布动态功能
2018/07/15 Javascript
vue2.0 如何在hash模式下实现微信分享
2019/01/22 Javascript
uniapp,微信小程序中使用 MQTT的问题
2020/07/11 Javascript
多个应用共存的Django配置方法
2018/05/30 Python
通过python将大量文件按修改时间分类的方法
2018/10/17 Python
WxPython建立批量录入框窗口
2019/02/27 Python
Django中如何使用sass的方法步骤
2019/07/09 Python
解决Django中多条件查询的问题
2019/07/18 Python
np.newaxis 实现为 numpy.ndarray(多维数组)增加一个轴
2019/11/30 Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
2020/01/18 Python
用你熟悉的语言写一个连接ORACLE数据库的程序,能够完成修改和查询工作
2012/06/11 面试题
Linux面试题LINUX系统类
2014/11/19 面试题
旅行社各个岗位职责
2014/03/15 职场文书
毕业留言寄语大全
2014/04/10 职场文书
英语演讲稿3分钟
2014/04/29 职场文书
欢迎领导检查标语
2014/06/27 职场文书
用python开发一款操作MySQL的小工具
2021/05/12 Python
使用Redis实现点赞取消点赞的详细代码
2022/03/20 Redis