python实现决策树、随机森林的简单原理


Posted in Python onMarch 26, 2018

本文申明:此文为学习记录过程,中间多处引用大师讲义和内容。

一、概念

决策树(Decision Tree)是一种简单但是广泛使用的分类器。通过训练数据构建决策树,可以高效的对未知的数据进行分类。决策数有两大优点:1)决策树模型可以读性好,具有描述性,有助于人工分析;2)效率高,决策树只需要一次构建,反复使用,每一次预测的最大计算次数不超过决策树的深度。

看了一遍概念后,我们先从一个简单的案例开始,如下图我们样本:

python实现决策树、随机森林的简单原理

对于上面的样本数据,根据不同特征值我们最后是选择是否约会,我们先自定义的一个决策树,决策树如下图所示:

python实现决策树、随机森林的简单原理

对于上图中的决策树,有个疑问,就是为什么第一个选择是“长相”这个特征,我选择“收入”特征作为第一分类的标准可以嘛?下面我们就对构建决策树选择特征的问题进行讨论;在考虑之前我们要先了解一下相关的数学知识:

    信息熵:熵代表信息的不确定性,信息的不确定性越大,熵越大;比如“明天太阳从东方升起”这一句话代表的信息我们可以认为为0;因为太阳从东方升起是个特定的规律,我们可以把这个事件的信息熵约等于0;说白了,信息熵和事件发生的概率成反比:数学上把信息熵定义如下:H(X)=H(P1,P2,…,Pn)=-∑P(xi)logP(xi)

   互信息:指的是两个随机变量之间的关联程度,即给定一个随机变量后,另一个随机变量不确定性的削弱程度,因而互信息取值最小为0,意味着给定一个随机变量对确定一另一个随机变量没有关系,最大取值为随机变量的熵,意味着给定一个随机变量,能完全消除另一个随机变量的不确定性

现在我们就把信息熵运用到决策树特征选择上,对于选择哪个特征我们按照这个规则进行“哪个特征能使信息的确定性最大我们就选择哪个特征”;比如上图的案例中;

第一步:假设约会去或不去的的事件为Y,其信息熵为H(Y);

第二步:假设给定特征的条件下,其条件信息熵分别为H(Y|长相),H(Y|收入),H(Y|身高)

第三步:分别计算信息增益(互信息):G(Y,长相) = I(Y,长相) = H(Y)-H(Y|长相) 、G(Y,) = I(Y,长相) = H(Y)-H(Y|长相)等

第四部:选择信息增益最大的特征作为分类特征;因为增益信息大的特征意味着给定这个特征,能很大的消除去约会还是不约会的不确定性;

第五步:迭代选择特征即可;

按以上就解决了决策树的分类特征选择问题,上面的这种方法就是ID3方法,当然还是别的方法如 C4.5;等;

二、决策树的过拟合解决办法

   若决策树的度过深的话会出现过拟合现象,对于决策树的过拟合有二个方案:

   1.剪枝-先剪枝和后剪纸(可以在构建决策树的时候通过指定深度,每个叶子的样本数来达到剪枝的作用)

   2.随机森林 --构建大量的决策树组成森林来防止过拟合;虽然单个树可能存在过拟合,但通过广度的增加就会消除过拟合现象

三、随机森林

随机森林是一个最近比较火的算法,它有很多的优点:

  • 在数据集上表现良好
  • 在当前的很多数据集上,相对其他算法有着很大的优势
  • 它能够处理很高维度(feature很多)的数据,并且不用做特征选择
  • 在训练完后,它能够给出哪些feature比较重要
  • 训练速度快
  • 在训练过程中,能够检测到feature间的互相影响
  • 容易做成并行化方法
  • 实现比较简单

随机森林顾名思义,是用随机的方式建立一个森林,森林里面有很多的决策树组成,随机森林的每一棵决策树之间是没有关联的。在得到森林之后,当有一个新的输入样本进入的时候,就让森林中的每一棵决策树分别进行一下判断,看看这个样本应该属于哪一类(对于分类算法),然后看看哪一类被选择最多,就预测这个样本为那一类。

上一段决策树代码:

# 花萼长度、花萼宽度,花瓣长度,花瓣宽度 
iris_feature_E = 'sepal length', 'sepal width', 'petal length', 'petal width' 
iris_feature = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度' 
iris_class = 'Iris-setosa', 'Iris-versicolor', 'Iris-virginica' 
 
 
if __name__ == "__main__": 
  mpl.rcParams['font.sans-serif'] = [u'SimHei'] 
  mpl.rcParams['axes.unicode_minus'] = False 
 
  path = '..\\8.Regression\\iris.data' # 数据文件路径 
  data = pd.read_csv(path, header=None) 
  x = data[range(4)] 
  y = pd.Categorical(data[4]).codes 
  # 为了可视化,仅使用前两列特征 
  x = x.iloc[:, :2] 
  x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.7, random_state=1) 
  print y_test.shape 
 
  # 决策树参数估计 
  # min_samples_split = 10:如果该结点包含的样本数目大于10,则(有可能)对其分支 
  # min_samples_leaf = 10:若将某结点分支后,得到的每个子结点样本数目都大于10,则完成分支;否则,不进行分支 
  model = DecisionTreeClassifier(criterion='entropy') 
  model.fit(x_train, y_train) 
  y_test_hat = model.predict(x_test)   # 测试数据 
 
  # 保存 
  # dot -Tpng my.dot -o my.png 
  # 1、输出 
  with open('iris.dot', 'w') as f: 
    tree.export_graphviz(model, out_file=f) 
  # 2、给定文件名 
  # tree.export_graphviz(model, out_file='iris1.dot') 
  # 3、输出为pdf格式 
  dot_data = tree.export_graphviz(model, out_file=None, feature_names=iris_feature_E, class_names=iris_class, 
                  filled=True, rounded=True, special_characters=True) 
  graph = pydotplus.graph_from_dot_data(dot_data) 
  graph.write_pdf('iris.pdf') 
  f = open('iris.png', 'wb') 
  f.write(graph.create_png()) 
  f.close() 
 
  # 画图 
  N, M = 50, 50 # 横纵各采样多少个值 
  x1_min, x2_min = x.min() 
  x1_max, x2_max = x.max() 
  t1 = np.linspace(x1_min, x1_max, N) 
  t2 = np.linspace(x2_min, x2_max, M) 
  x1, x2 = np.meshgrid(t1, t2) # 生成网格采样点 
  x_show = np.stack((x1.flat, x2.flat), axis=1) # 测试点 
  print x_show.shape 
 
  # # 无意义,只是为了凑另外两个维度 
  # # 打开该注释前,确保注释掉x = x[:, :2] 
  # x3 = np.ones(x1.size) * np.average(x[:, 2]) 
  # x4 = np.ones(x1.size) * np.average(x[:, 3]) 
  # x_test = np.stack((x1.flat, x2.flat, x3, x4), axis=1) # 测试点 
 
  cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF']) 
  cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b']) 
  y_show_hat = model.predict(x_show) # 预测值 
  print y_show_hat.shape 
  print y_show_hat 
  y_show_hat = y_show_hat.reshape(x1.shape) # 使之与输入的形状相同 
  print y_show_hat 
  plt.figure(facecolor='w') 
  plt.pcolormesh(x1, x2, y_show_hat, cmap=cm_light) # 预测值的显示 
  plt.scatter(x_test[0], x_test[1], c=y_test.ravel(), edgecolors='k', s=150, zorder=10, cmap=cm_dark, marker='*') # 测试数据 
  plt.scatter(x[0], x[1], c=y.ravel(), edgecolors='k', s=40, cmap=cm_dark) # 全部数据 
  plt.xlabel(iris_feature[0], fontsize=15) 
  plt.ylabel(iris_feature[1], fontsize=15) 
  plt.xlim(x1_min, x1_max) 
  plt.ylim(x2_min, x2_max) 
  plt.grid(True) 
  plt.title(u'鸢尾花数据的决策树分类', fontsize=17) 
  plt.show()

python实现决策树、随机森林的简单原理

python实现决策树、随机森林的简单原理

以上就是决策树做分类,但决策树也可以用来做回归,不说直接上代码:

if __name__ == "__main__": 
  N =100 
  x = np.random.rand(N) *6 -3 
  x.sort() 
  y = np.sin(x) + np.random.randn(N) *0.05 
  x = x.reshape(-1,1) 
  print x 
  dt = DecisionTreeRegressor(criterion='mse',max_depth=9) 
  dt.fit(x,y) 
  x_test = np.linspace(-3,3,50).reshape(-1,1) 
  y_hat = dt.predict(x_test) 
 
  plt.plot(x,y,'r*',ms =5,label='Actual') 
  plt.plot(x_test,y_hat,'g-',linewidth=2,label='predict') 
  plt.legend(loc ='upper left') 
  plt.grid() 
  plt.show() 
 
  #比较决策树的深度影响 
  depth =[2,4,6,8,10] 
  clr = 'rgbmy' 
  dtr = DecisionTreeRegressor(criterion='mse') 
  plt.plot(x,y,'ko',ms=6,label='Actual') 
  x_test = np.linspace(-3,3,50).reshape(-1,1) 
  for d,c in zip(depth,clr): 
    dtr.set_params(max_depth=d) 
    dtr.fit(x,y) 
    y_hat = dtr.predict(x_test) 
    plt.plot(x_test,y_hat,'-',color=c,linewidth =2,label='Depth=%d' % d) 
  plt.legend(loc='upper left') 
  plt.grid(b =True) 
  plt.show()

python实现决策树、随机森林的简单原理

不同深度对回归的 影响如下图:

python实现决策树、随机森林的简单原理

下面上个随机森林代码

mpl.rcParams['font.sans-serif'] = [u'SimHei'] # 黑体 FangSong/KaiTi 
mpl.rcParams['axes.unicode_minus'] = False 
 
path = 'iris.data' # 数据文件路径 
data = pd.read_csv(path, header=None) 
x_prime = data[range(4)] 
y = pd.Categorical(data[4]).codes 
feature_pairs = [[0, 1]] 
plt.figure(figsize=(10,9),facecolor='#FFFFFF') 
for i,pair in enumerate(feature_pairs): 
  x = x_prime[pair] 
  clf = RandomForestClassifier(n_estimators=200,criterion='entropy',max_depth=3) 
  clf.fit(x,y.ravel()) 
  N, M =50,50 
  x1_min,x2_min = x.min() 
  x1_max,x2_max = x.max() 
  t1 = np.linspace(x1_min,x1_max, N) 
  t2 = np.linspace(x2_min,x2_max, M) 
  x1,x2 = np.meshgrid(t1,t2) 
  x_test = np.stack((x1.flat,x2.flat),axis =1) 
  y_hat = clf.predict(x) 
  y = y.reshape(-1) 
  c = np.count_nonzero(y_hat == y) 
  print '特征:',iris_feature[pair[0]],'+',iris_feature[pair[1]] 
  print '\t 预测正确数目:',c 
  cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF']) 
  cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b']) 
  y_hat = clf.predict(x_test) 
  y_hat = y_hat.reshape(x1.shape) 
  plt.pcolormesh(x1,x2,y_hat,cmap =cm_light) 
  plt.scatter(x[pair[0]],x[pair[1]],c=y,edgecolors='k',cmap=cm_dark) 
  plt.xlabel(iris_feature[pair[0]],fontsize=12) 
  plt.ylabel(iris_feature[pair[1]], fontsize=14) 
  plt.xlim(x1_min, x1_max) 
  plt.ylim(x2_min, x2_max) 
  plt.grid() 
plt.tight_layout(2.5) 
plt.subplots_adjust(top=0.92) 
plt.suptitle(u'随机森林对鸢尾花数据的两特征组合的分类结果', fontsize=18) 
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中使用pyhook实现键盘监控的例子
Jul 18 Python
python条件变量之生产者与消费者操作实例分析
Mar 22 Python
python实现折半查找和归并排序算法
Apr 14 Python
浅谈python迭代器
Nov 08 Python
详解appium+python 启动一个app步骤
Dec 20 Python
解决seaborn在pycharm中绘图不出图的问题
May 24 Python
python 移动图片到另外一个文件夹的实例
Jan 10 Python
关于python字符串方法分类详解
Aug 20 Python
在主流系统之上安装Pygame的方法
May 20 Python
Python requests库参数提交的注意事项总结
Mar 29 Python
python机器学习Github已达8.9Kstars模型解释器LIME
Nov 23 Python
asyncio异步编程之Task对象详解
Mar 13 Python
python机器学习之贝叶斯分类
Mar 26 #Python
利用python实现微信头像加红色数字功能
Mar 26 #Python
Python扩展内置类型详解
Mar 26 #Python
python函数式编程学习之yield表达式形式详解
Mar 25 #Python
Python实现简单求解给定整数的质因数算法示例
Mar 25 #Python
python实现隐马尔科夫模型HMM
Mar 25 #Python
Python实现的寻找前5个默尼森数算法示例
Mar 25 #Python
You might like
PHP session常见问题集锦及解决办法总结
2007/03/18 PHP
PHP的简易冒泡法代码分享
2012/08/28 PHP
Thinkphp的volist标签嵌套循环使用教程
2014/07/08 PHP
使用PHP进行微信公众平台开发的示例
2015/08/21 PHP
PHP使用MPDF类生成PDF的方法
2015/12/08 PHP
实例讲解PHP表单验证功能
2019/02/15 PHP
JQuery UI皮肤定制
2009/07/27 Javascript
poshytip 基于jquery的 插件 主要用于显示微博人的图像和鼠标提示等
2012/10/12 Javascript
IE网页js语法错误2行字符1、FF中正常的解决方法
2013/09/09 Javascript
简单的代码实现jquery定时器
2014/01/03 Javascript
jQuery中slideUp()方法用法分析
2014/12/24 Javascript
jQuery插件实现大图全屏图片相册
2015/03/14 Javascript
JavaScript实现动态删除列表框值的方法
2015/08/12 Javascript
《JavaScript高级编程》学习笔记之object和array引用类型
2015/11/01 Javascript
Bootstrap table两种分页示例
2016/12/23 Javascript
AngularJS ionic手势事件的使用总结
2017/08/09 Javascript
基于layui数据表格以及传数据的方式
2018/08/19 Javascript
小程序接入腾讯位置服务的详细流程
2020/03/03 Javascript
Vue之封装公用变量以及实现方式
2020/07/31 Javascript
vue任意关系组件通信与跨组件监听状态vue-communication
2020/10/18 Javascript
JavaScript实现简易计算器小功能
2020/10/22 Javascript
Python Sql数据库增删改查操作简单封装
2016/04/18 Python
Python中with及contextlib的用法详解
2017/06/08 Python
python实现时间o(1)的最小栈的实例代码
2018/07/23 Python
Python实现K折交叉验证法的方法步骤
2019/07/11 Python
解决pymysql cursor.fetchall() 获取不到数据的问题
2020/05/15 Python
体验完美剃须:The Art of Shaving
2018/08/06 全球购物
设备动力科岗位职责范本
2014/02/23 职场文书
体育活动总结范文
2014/05/04 职场文书
2015年社区党建工作汇报材料
2015/06/25 职场文书
同学聚会感言一句话
2015/07/30 职场文书
大队委员竞选演讲稿
2015/11/20 职场文书
2016幼儿园教师节新闻稿
2015/11/25 职场文书
2016小学优秀教师先进事迹材料
2016/02/26 职场文书
聊聊redis-dump工具安装问题
2022/01/18 Redis
python中urllib包的网络请求教程
2022/04/19 Python