Python实现的随机森林算法与简单总结


Posted in Python onJanuary 30, 2018

本文实例讲述了Python实现的随机森林算法。分享给大家供大家参考,具体如下:

随机森林是数据挖掘中非常常用的分类预测算法,以分类或回归的决策树为基分类器。算法的一些基本要点:

*对大小为m的数据集进行样本量同样为m的有放回抽样;
*对K个特征进行随机抽样,形成特征的子集,样本量的确定方法可以有平方根、自然对数等;
*每棵树完全生成,不进行剪枝;
*每个样本的预测结果由每棵树的预测投票生成(回归的时候,即各棵树的叶节点的平均)

著名的python机器学习包scikit learn的文档对此算法有比较详尽的介绍: http://scikit-learn.org/stable/modules/ensemble.html#random-forests

出于个人研究和测试的目的,基于经典的Kaggle 101泰坦尼克号乘客的数据集,建立模型并进行评估。比赛页面及相关数据集的下载:https://www.kaggle.com/c/titanic

泰坦尼克号的沉没,是历史上非常著名的海难。突然感到,自己面对的不再是冷冰冰的数据,而是用数据挖掘的方法,去研究具体的历史问题,也是饶有兴趣。言归正传,模型的主要的目标,是希望根据每个乘客的一系列特征,如性别、年龄、舱位、上船地点等,对其是否能生还进行预测,是非常典型的二分类预测问题。数据集的字段名及实例如下:

PassengerId Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
1 0 3 Braund, Mr. Owen Harris male 22 1 0 A/5 21171 7.25 S
2 1 1 Cumings, Mrs. John Bradley (Florence Briggs Thayer) female 38 1 0 PC 17599 71.2833 C85 C
3 1 3 Heikkinen, Miss. Laina female 26 0 0 STON/O2. 3101282 7.925 S
4 1 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35 1 0 113803 53.1 C123 S
5 0 3 Allen, Mr. William Henry male 35 0 0 373450 8.05 S

值得说明的是,SibSp是指sister brother spouse,即某个乘客随行的兄弟姐妹、丈夫、妻子的人数,Parch指parents,children

下面给出整个数据处理及建模过程,基于ubuntu+python 3.4( anaconda科学计算环境已经集成一系列常用包,pandas numpy sklearn等,这里强烈推荐)

懒得切换输入法,写的时候主要的注释都是英文,中文的注释是后来补充的:-)

# -*- coding: utf-8 -*-
"""
@author: kim
"""
from model import *#载入基分类器的代码
#ETL:same procedure to training set and test set
training=pd.read_csv('train.csv',index_col=0)
test=pd.read_csv('test.csv',index_col=0)
SexCode=pd.DataFrame([1,0],index=['female','male'],columns=['Sexcode']) #将性别转化为01
training=training.join(SexCode,how='left',on=training.Sex)
training=training.drop(['Name','Ticket','Embarked','Cabin','Sex'],axis=1)#删去几个不参与建模的变量,包括姓名、船票号,船舱号
test=test.join(SexCode,how='left',on=test.Sex)
test=test.drop(['Name','Ticket','Embarked','Cabin','Sex'],axis=1)
print('ETL IS DONE!')
#MODEL FITTING
#===============PARAMETER AJUSTMENT============
min_leaf=1
min_dec_gini=0.0001
n_trees=5
n_fea=int(math.sqrt(len(training.columns)-1))
#==============================================
'''''
BEST SCORE:0.83
min_leaf=30
min_dec_gini=0.001
n_trees=20
'''
#ESSEMBLE BY RANDOM FOREST
FOREST={}
tmp=list(training.columns)
tmp.pop(tmp.index('Survived'))
feaList=pd.Series(tmp)
for t in range(n_trees):
#  fea=[]
  feasample=feaList.sample(n=n_fea,replace=False)#select feature
  fea=feasample.tolist()
  fea.append('Survived')
#    feaNew=fea.append(target)
  subset=training.sample(n=len(training),replace=True)#generate the dataset with replacement
  subset=subset[fea]
#  print(str(t)+' Classifier built on feature:')
#  print(list(fea))
  FOREST[t]=tree_grow(subset,'Survived',min_leaf,min_dec_gini) #save the tree
#MODEL PREDICTION
#======================
currentdata=training
output='submission_rf_20151116_30_0.001_20'
#======================
prediction={}
for r in currentdata.index:#a row
  prediction_vote={1:0,0:0}
  row=currentdata.get(currentdata.index==r)
  for n in range(n_trees):
    tree_dict=FOREST[n] #a tree
    p=model_prediction(tree_dict,row)
    prediction_vote[p]+=1
  vote=pd.Series(prediction_vote)
  prediction[r]=list(vote.order(ascending=False).index)[0]#the vote result
result=pd.Series(prediction,name='Survived_p')
#del prediction_vote
#del prediction
#result.to_csv(output)
t=training.join(result,how='left')
accuracy=round(len(t[t['Survived']==t['Survived_p']])/len(t),5)
print(accuracy)

上述是随机森林的代码,如上所述,随机森林是一系列决策树的组合,决策树每次分裂,用Gini系数衡量当前节点的“不纯净度”,如果按照某个特征的某个分裂点对数据集划分后,能够让数据集的Gini下降最多(显著地减少了数据集输出变量的不纯度),则选为当前最佳的分割特征及分割点。代码如下:

# -*- coding: utf-8 -*-
"""
@author: kim
"""
import pandas as pd
import numpy as np
#import sklearn as sk
import math
def tree_grow(dataframe,target,min_leaf,min_dec_gini):
  tree={} #renew a tree
  is_not_leaf=(len(dataframe)>min_leaf)
  if is_not_leaf:
    fea,sp,gd=best_split_col(dataframe,target)
    if gd>min_dec_gini:
      tree['fea']=fea
      tree['val']=sp
#      dataframe.drop(fea,axis=1) #1116 modified
      l,r=dataSplit(dataframe,fea,sp)
      l.drop(fea,axis=1)
      r.drop(fea,axis=1)
      tree['left']=tree_grow(l,target,min_leaf,min_dec_gini)
      tree['right']=tree_grow(r,target,min_leaf,min_dec_gini)
    else:#return a leaf
      return leaf(dataframe[target])
  else:
    return leaf(dataframe[target])
  return tree
def leaf(class_lable):
  tmp={}
  for i in class_lable:
    if i in tmp:
      tmp[i]+=1
    else:
      tmp[i]=1
  s=pd.Series(tmp)
  s.sort(ascending=False)
  return s.index[0]
def gini_cal(class_lable):
  p_1=sum(class_lable)/len(class_lable)
  p_0=1-p_1
  gini=1-(pow(p_0,2)+pow(p_1,2))
  return gini
def dataSplit(dataframe,split_fea,split_val):
  left_node=dataframe[dataframe[split_fea]<=split_val]
  right_node=dataframe[dataframe[split_fea]>split_val]
  return left_node,right_node
def best_split_col(dataframe,target_name):
  best_fea=''#modified 1116
  best_split_point=0
  col_list=list(dataframe.columns)
  col_list.remove(target_name)
  gini_0=gini_cal(dataframe[target_name])
  n=len(dataframe)
  gini_dec=-99999999
  for col in col_list:
    node=dataframe[[col,target_name]]
    unique=node.groupby(col).count().index
    for split_point in unique: #unique value
      left_node,right_node=dataSplit(node,col,split_point)
      if len(left_node)>0 and len(right_node)>0:
        gini_col=gini_cal(left_node[target_name])*(len(left_node)/n)+gini_cal(right_node[target_name])*(len(right_node)/n)
        if (gini_0-gini_col)>gini_dec:
          gini_dec=gini_0-gini_col#decrease of impurity
          best_fea=col
          best_split_point=split_point
    #print(col,split_point,gini_0-gini_col)
  return best_fea,best_split_point,gini_dec
def model_prediction(model,row): #row is a df
  fea=model['fea']
  val=model['val']
  left=model['left']
  right=model['right']
  if row[fea].tolist()[0]<=val:#get the value
    branch=left
  else:
    branch=right
  if ('dict' in str( type(branch) )):
    prediction=model_prediction(branch,row)
  else:
    prediction=branch
  return prediction

实际上,上面的代码还有很大的效率提升的空间,数据集不是很大的情况下,如果选择一个较大的输入参数,例如生成100棵树,就会显著地变慢;同时,将预测结果提交至kaggle进行评测,发现在测试集上的正确率不是很高,比使用sklearn里面相应的包进行预测的正确率(0.77512)要稍低一点 :-(  如果要提升准确率,两个大方向: 构造新的特征;调整现有模型的参数。

这里是抛砖引玉,欢迎大家对我的建模思路和算法的实现方法提出修改意见。

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python break语句详解
Mar 11 Python
python使用ctypes模块调用windowsapi获取系统版本示例
Apr 17 Python
Python文档生成工具pydoc使用介绍
Jun 02 Python
Python对文件操作知识汇总
May 15 Python
python 遍历字符串(含汉字)实例详解
Apr 04 Python
对Python3.6 IDLE常用快捷键介绍
Jul 16 Python
python文件操作之批量修改文件后缀名的方法
Aug 10 Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 Python
tensorflow2.0的函数签名与图结构(推荐)
Apr 28 Python
Django权限设置及验证方式
May 13 Python
PyInstaller运行原理及常用操作详解
Jun 13 Python
python 两种方法修改文件的创建时间、修改时间、访问时间
Sep 26 Python
Python决策树和随机森林算法实例详解
Jan 30 #Python
在Python 2.7即将停止支持时,我们为你带来了一份python 3.x迁移指南
Jan 30 #Python
python使用Tkinter实现在线音乐播放器
Jan 30 #Python
Python字典及字典基本操作方法详解
Jan 30 #Python
Python操作MySQL数据库的三种方法总结
Jan 30 #Python
python3.5 tkinter实现页面跳转
Jan 30 #Python
python 连接各类主流数据库的实例代码
Jan 30 #Python
You might like
thinkPHP模板引擎用法示例
2016/12/08 PHP
Firefox和IE浏览器兼容JS脚本写法小结
2008/07/07 Javascript
(jQuery,mootools,dojo)使用适合自己的编程别名命名
2010/09/14 Javascript
使用Jquery实现点击文字后变成文本框且可修改
2013/09/21 Javascript
JS实现根据文件字节数返回文件大小的方法
2016/08/02 Javascript
JS定时检测任务任务完成后执行下一步的解决办法
2016/12/22 Javascript
vue2.0结合DataTable插件实现表格动态刷新的方法详解
2017/03/17 Javascript
JS自定义函数实现时间戳转换成date的方法示例
2017/08/27 Javascript
nodejs实现超简单生成二维码的方法
2018/03/17 NodeJs
JS构造一个html文本内容成文件流形式发送到后台
2018/07/31 Javascript
Vue学习之axios的使用方法实例分析
2020/01/06 Javascript
JS实现滑动导航效果
2020/01/14 Javascript
Vue v-for中的 input 或 select的值发生改变时触发事件操作
2020/08/31 Javascript
用webAPI实现图片放大镜效果
2020/11/23 Javascript
[20:57]Ti4主赛事第三天开幕式
2014/07/21 DOTA
Python脚本简单实现打开默认浏览器登录人人和打开QQ的方法
2016/04/12 Python
Python简单定义与使用字典dict的方法示例
2017/07/25 Python
Python cookbook(数据结构与算法)根据字段将记录分组操作示例
2018/03/19 Python
python实现全盘扫描搜索功能的方法
2019/02/14 Python
为什么你还不懂得怎么使用Python协程
2019/05/13 Python
python 正则表达式贪婪模式与非贪婪模式原理、用法实例分析
2019/10/14 Python
Python利用全连接神经网络求解MNIST问题详解
2020/01/14 Python
django使用JWT保存用户登录信息
2020/04/22 Python
pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)
2020/06/24 Python
如何用Django处理gzip数据流
2021/01/29 Python
HTML5的自定义属性data-*详细介绍和JS操作实例
2014/04/10 HTML / CSS
详解Canvas实用库Fabric.js使用手册
2019/01/07 HTML / CSS
野兽派官方旗舰店:THE BEAST 野兽派
2016/08/05 全球购物
澳大利亚二手奢侈品网站:Modsie
2019/09/23 全球购物
java程序员面试交流
2012/11/29 面试题
小学生倡议书范文
2014/05/13 职场文书
档案工作汇报材料
2014/08/21 职场文书
2014年银行个人工作总结
2014/12/05 职场文书
2015公务员试用期工作总结
2014/12/12 职场文书
2016年母亲节寄语
2015/12/04 职场文书
Python torch.flatten()函数案例详解
2021/08/30 Python