python实现AdaBoost算法的示例


Posted in Python onOctober 03, 2020

代码

'''
数据集:Mnist
训练集数量:60000(实际使用:10000)
测试集数量:10000(实际使用:1000)
层数:40
------------------------------
运行结果:
  正确率:97%
  运行时长:65m
'''

import time
import numpy as np


def loadData(fileName):
  '''
  加载文件
  :param fileName:要加载的文件路径
  :return: 数据集和标签集
  '''
  # 存放数据及标记
  dataArr = []
  labelArr = []
  # 读取文件
  fr = open(fileName)
  # 遍历文件中的每一行
  for line in fr.readlines():
    # 获取当前行,并按“,”切割成字段放入列表中
    # strip:去掉每行字符串首尾指定的字符(默认空格或换行符)
    # split:按照指定的字符将字符串切割成每个字段,返回列表形式
    curLine = line.strip().split(',')
    # 将每行中除标记外的数据放入数据集中(curLine[0]为标记信息)
    # 在放入的同时将原先字符串形式的数据转换为整型
    # 此外将数据进行了二值化处理,大于128的转换成1,小于的转换成0,方便后续计算
    dataArr.append([int(int(num) > 128) for num in curLine[1:]])
    # 将标记信息放入标记集中
    # 放入的同时将标记转换为整型

    # 转换成二分类任务
    # 标签0设置为1,反之为-1
    if int(curLine[0]) == 0:
      labelArr.append(1)
    else:
      labelArr.append(-1)
  # 返回数据集和标记
  return dataArr, labelArr


def calc_e_Gx(trainDataArr, trainLabelArr, n, div, rule, D):
  '''
  计算分类错误率
  :param trainDataArr:训练数据集数字
  :param trainLabelArr: 训练标签集数组
  :param n: 要操作的特征
  :param div:划分点
  :param rule:正反例标签
  :param D:权值分布D
  :return:预测结果, 分类误差率
  '''
  # 初始化分类误差率为0
  e = 0
  # 将训练数据矩阵中特征为n的那一列单独剥出来做成数组。因为其他元素我们并不需要,
  # 直接对庞大的训练集进行操作的话会很慢
  x = trainDataArr[:, n]
  # 同样将标签也转换成数组格式,x和y的转换只是单纯为了提高运行速度
  # 测试过相对直接操作而言性能提升很大
  y = trainLabelArr
  predict = []

  # 依据小于和大于的标签依据实际情况会不同,在这里直接进行设置
  if rule == 'LisOne':
    L = 1
    H = -1
  else:
    L = -1
    H = 1

  # 遍历所有样本的特征m
  for i in range(trainDataArr.shape[0]):
    if x[i] < div:
      # 如果小于划分点,则预测为L
      # 如果设置小于div为1,那么L就是1,
      # 如果设置小于div为-1,L就是-1
      predict.append(L)
      # 如果预测错误,分类错误率要加上该分错的样本的权值(8.1式)
      if y[i] != L:
        e += D[i]
    elif x[i] >= div:
      # 与上面思想一样
      predict.append(H)
      if y[i] != H:
        e += D[i]
  # 返回预测结果和分类错误率e
  # 预测结果其实是为了后面做准备的,在算法8.1第四步式8.4中exp内部有个Gx,要用在那个地方
  # 以此来更新新的D
  return np.array(predict), e


def createSigleBoostingTree(trainDataArr, trainLabelArr, D):
  '''
  创建单层提升树
  :param trainDataArr:训练数据集数组
  :param trainLabelArr: 训练标签集数组
  :param D: 算法8.1中的D
  :return: 创建的单层提升树
  '''

  # 获得样本数目及特征数量
  m, n = np.shape(trainDataArr)
  # 单层树的字典,用于存放当前层提升树的参数
  # 也可以认为该字典代表了一层提升树
  sigleBoostTree = {}
  # 初始化分类误差率,分类误差率在算法8.1步骤(2)(b)有提到
  # 误差率最高也只能100%,因此初始化为1
  sigleBoostTree['e'] = 1

  # 对每一个特征进行遍历,寻找用于划分的最合适的特征
  for i in range(n):
    # 因为特征已经经过二值化,只能为0和1,因此分切分时分为-0.5, 0.5, 1.5三挡进行切割
    for div in [-0.5, 0.5, 1.5]:
      # 在单个特征内对正反例进行划分时,有两种情况:
      # 可能是小于某值的为1,大于某值得为-1,也可能小于某值得是-1,反之为1
      # 因此在寻找最佳提升树的同时对于两种情况也需要遍历运行
      # LisOne:Low is one:小于某值得是1
      # HisOne:High is one:大于某值得是1
      for rule in ['LisOne', 'HisOne']:
        # 按照第i个特征,以值div进行切割,进行当前设置得到的预测和分类错误率
        Gx, e = calc_e_Gx(trainDataArr, trainLabelArr, i, div, rule, D)
        # 如果分类错误率e小于当前最小的e,那么将它作为最小的分类错误率保存
        if e < sigleBoostTree['e']:
          sigleBoostTree['e'] = e
          # 同时也需要存储最优划分点、划分规则、预测结果、特征索引
          # 以便进行D更新和后续预测使用
          sigleBoostTree['div'] = div
          sigleBoostTree['rule'] = rule
          sigleBoostTree['Gx'] = Gx
          sigleBoostTree['feature'] = i
  # 返回单层的提升树
  return sigleBoostTree


def createBosstingTree(trainDataList, trainLabelList, treeNum=50):
  '''
  创建提升树
  创建算法依据“8.1.2 AdaBoost算法” 算法8.1
  :param trainDataList:训练数据集
  :param trainLabelList: 训练测试集
  :param treeNum: 树的层数
  :return: 提升树
  '''
  # 将数据和标签转化为数组形式
  trainDataArr = np.array(trainDataList)
  trainLabelArr = np.array(trainLabelList)
  # 没增加一层数后,当前最终预测结果列表
  finallpredict = [0] * len(trainLabelArr)
  # 获得训练集数量以及特征个数
  m, n = np.shape(trainDataArr)

  # 依据算法8.1步骤(1)初始化D为1/N
  D = [1 / m] * m
  # 初始化提升树列表,每个位置为一层
  tree = []
  # 循环创建提升树
  for i in range(treeNum):
    # 得到当前层的提升树
    curTree = createSigleBoostingTree(trainDataArr, trainLabelArr, D)
    # 根据式8.2计算当前层的alpha
    alpha = 1 / 2 * np.log((1 - curTree['e']) / curTree['e'])
    # 获得当前层的预测结果,用于下一步更新D
    Gx = curTree['Gx']
    # 依据式8.4更新D
    # 考虑到该式每次只更新D中的一个w,要循环进行更新知道所有w更新结束会很复杂(其实
    # 不是时间上的复杂,只是让人感觉每次单独更新一个很累),所以该式以向量相乘的形式,
    # 一个式子将所有w全部更新完。
    # 该式需要线性代数基础,如果不太熟练建议补充相关知识,当然了,单独更新w也一点问题
    # 没有
    # np.multiply(trainLabelArr, Gx):exp中的y*Gm(x),结果是一个行向量,内部为yi*Gm(xi)
    # np.exp(-1 * alpha * np.multiply(trainLabelArr, Gx)):上面求出来的行向量内部全体
    # 成员再乘以-αm,然后取对数,和书上式子一样,只不过书上式子内是一个数,这里是一个向量
    # D是一个行向量,取代了式中的wmi,然后D求和为Zm
    # 书中的式子最后得出来一个数w,所有数w组合形成新的D
    # 这里是直接得到一个向量,向量内元素是所有的w
    # 本质上结果是相同的
    D = np.multiply(D, np.exp(-1 * alpha * np.multiply(trainLabelArr, Gx))) / sum(D)
    # 在当前层参数中增加alpha参数,预测的时候需要用到
    curTree['alpha'] = alpha
    # 将当前层添加到提升树索引中。
    tree.append(curTree)

    # -----以下代码用来辅助,可以去掉---------------
    # 根据8.6式将结果加上当前层乘以α,得到目前的最终输出预测
    finallpredict += alpha * Gx
    # 计算当前最终预测输出与实际标签之间的误差
    error = sum([1 for i in range(len(trainDataList)) if np.sign(finallpredict[i]) != trainLabelArr[i]])
    # 计算当前最终误差率
    finallError = error / len(trainDataList)
    # 如果误差为0,提前退出即可,因为没有必要再计算算了
    if finallError == 0:
      return tree
    # 打印一些信息
    print('iter:%d:%d, sigle error:%.4f, finall error:%.4f' % (i, treeNum, curTree['e'], finallError))
  # 返回整个提升树
  return tree


def predict(x, div, rule, feature):
  '''
  输出单独层预测结果
  :param x: 预测样本
  :param div: 划分点
  :param rule: 划分规则
  :param feature: 进行操作的特征
  :return:
  '''
  # 依据划分规则定义小于及大于划分点的标签
  if rule == 'LisOne':
    L = 1
    H = -1
  else:
    L = -1
    H = 1

  # 判断预测结果
  if x[feature] < div:
    return L
  else:
    return H


def test(testDataList, testLabelList, tree):
  '''
  测试
  :param testDataList:测试数据集
  :param testLabelList: 测试标签集
  :param tree: 提升树
  :return: 准确率
  '''
  # 错误率计数值
  errorCnt = 0
  # 遍历每一个测试样本
  for i in range(len(testDataList)):
    # 预测结果值,初始为0
    result = 0
    # 依据算法8.1式8.6
    # 预测式子是一个求和式,对于每一层的结果都要进行一次累加
    # 遍历每层的树
    for curTree in tree:
      # 获取该层参数
      div = curTree['div']
      rule = curTree['rule']
      feature = curTree['feature']
      alpha = curTree['alpha']
      # 将当前层结果加入预测中
      result += alpha * predict(testDataList[i], div, rule, feature)
    # 预测结果取sign值,如果大于0 sign为1,反之为0
    if np.sign(result) != testLabelList[i]: 
      errorCnt += 1
  # 返回准确率
  return 1 - errorCnt / len(testDataList)


if __name__ == '__main__':
  # 开始时间
  start = time.time()

  # 获取训练集
  print('start read transSet')
  trainDataList, trainLabelList = loadData('../Mnist/mnist_train.csv')

  # 获取测试集
  print('start read testSet')
  testDataList, testLabelList = loadData('../Mnist/mnist_test.csv')

  # 创建提升树
  print('start init train')
  tree = createBosstingTree(trainDataList[:10000], trainLabelList[:10000], 40)

  # 测试
  print('start to test')
  accuracy = test(testDataList[:1000], testLabelList[:1000], tree)
  print('the accuracy is:%d' % (accuracy * 100), '%')

  # 结束时间
  end = time.time()
  print('time span:', end - start)

程序运行结果

start read transSet
start read testSet
start init train
iter:0:40, sigle error:0.0804, finall error:0.0804
iter:1:40, sigle error:0.1448, finall error:0.0804
iter:2:40, sigle error:0.1362, finall error:0.0585
iter:3:40, sigle error:0.1864, finall error:0.0667
iter:4:40, sigle error:0.2249, finall error:0.0474
iter:5:40, sigle error:0.2634, finall error:0.0437
iter:6:40, sigle error:0.2626, finall error:0.0377
iter:7:40, sigle error:0.2935, finall error:0.0361
iter:8:40, sigle error:0.3230, finall error:0.0333
iter:9:40, sigle error:0.3034, finall error:0.0361
iter:10:40, sigle error:0.3375, finall error:0.0325
iter:11:40, sigle error:0.3364, finall error:0.0340
iter:12:40, sigle error:0.3473, finall error:0.0309
iter:13:40, sigle error:0.3006, finall error:0.0294
iter:14:40, sigle error:0.3267, finall error:0.0275
iter:15:40, sigle error:0.3584, finall error:0.0288
iter:16:40, sigle error:0.3492, finall error:0.0257
iter:17:40, sigle error:0.3506, finall error:0.0256
iter:18:40, sigle error:0.3665, finall error:0.0240
iter:19:40, sigle error:0.3769, finall error:0.0251
iter:20:40, sigle error:0.3828, finall error:0.0213
iter:21:40, sigle error:0.3733, finall error:0.0229
iter:22:40, sigle error:0.3785, finall error:0.0218
iter:23:40, sigle error:0.3867, finall error:0.0219
iter:24:40, sigle error:0.3850, finall error:0.0208
iter:25:40, sigle error:0.3823, finall error:0.0201
iter:26:40, sigle error:0.3825, finall error:0.0204
iter:27:40, sigle error:0.3874, finall error:0.0188
iter:28:40, sigle error:0.3952, finall error:0.0186
iter:29:40, sigle error:0.4018, finall error:0.0193
iter:30:40, sigle error:0.3889, finall error:0.0177
iter:31:40, sigle error:0.3939, finall error:0.0183
iter:32:40, sigle error:0.3838, finall error:0.0182
iter:33:40, sigle error:0.4021, finall error:0.0171
iter:34:40, sigle error:0.4119, finall error:0.0164
iter:35:40, sigle error:0.4093, finall error:0.0164
iter:36:40, sigle error:0.4135, finall error:0.0167
iter:37:40, sigle error:0.4099, finall error:0.0171
iter:38:40, sigle error:0.3871, finall error:0.0163
iter:39:40, sigle error:0.4085, finall error:0.0154
start to test
the accuracy is:97 %
time span: 3777.730945825577

以上就是python实现AdaBoost算法的示例的详细内容,更多关于python实现AdaBoost算法的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python实现从一组颜色中找出与给定颜色最接近颜色的方法
Mar 19 Python
python简单实现刷新智联简历
Mar 30 Python
Python 爬虫模拟登陆知乎
Sep 23 Python
python解决字符串倒序输出的问题
Jun 25 Python
python脚本开机自启的实现方法
Jun 28 Python
python读写csv文件方法详细总结
Jul 05 Python
python语言中有算法吗
Jun 16 Python
什么是Python包的循环导入
Sep 08 Python
Django如何实现防止XSS攻击
Oct 13 Python
K近邻法(KNN)相关知识总结以及如何用python实现
Jan 28 Python
python爬取youtube视频的示例代码
Mar 03 Python
Python创建SQL数据库流程逐步讲解
Sep 23 Python
Django创建一个后台的基本步骤记录
Oct 02 #Python
Python中qutip用法示例详解
Oct 02 #Python
如何利用Python给自己的头像加一个小国旗(小月饼)
Oct 02 #Python
Python通过fnmatch模块实现文件名匹配
Sep 30 #Python
Python tempfile模块生成临时文件和临时目录
Sep 30 #Python
Python实现定时监测网站运行状态的示例代码
Sep 30 #Python
python如何实现word批量转HTML
Sep 30 #Python
You might like
PHP内存缓存功能memcached示例
2016/10/19 PHP
ThinkPHP实现分页功能
2017/04/28 PHP
js验证表单大全
2006/11/25 Javascript
javascript getElementsByClassName 和js取地址栏参数
2010/01/02 Javascript
ExtJS 入门
2010/10/29 Javascript
js获取元素外链样式的方法
2015/01/27 Javascript
javascript实现瀑布流自适应遇到的问题及解决方案
2015/01/28 Javascript
jQuery图片特效插件Revealing实现拉伸放大
2015/04/22 Javascript
JQuery中模拟image的ajaxPrefilter与ajaxTransport处理
2015/06/19 Javascript
Javascript技术栈中的四种依赖注入详解
2016/02/23 Javascript
AngularJS基础 ng-keypress 指令简单示例
2016/08/02 Javascript
jquery基本选择器匹配多个元素的实现方法
2016/09/05 Javascript
js实现倒计时及时间对象
2016/11/15 Javascript
网络传输协议(http协议)
2016/11/18 Javascript
jQuery实现动态添加节点与遍历节点功能示例
2017/11/09 jQuery
深入浅析Node.js 事件循环、定时器和process.nextTick()
2018/10/22 Javascript
推荐几个不错的console调试技巧实现
2019/12/20 Javascript
[54:51]Ti4 冒泡赛第二轮LGD vs C9 3
2014/07/14 DOTA
[02:17]DOTA2亚洲邀请赛 RAVE战队出场宣传片
2015/02/07 DOTA
Python Queue模块详细介绍及实例
2016/12/27 Python
python爬虫 正则表达式使用技巧及爬取个人博客的实例讲解
2017/10/20 Python
python实现抽奖小程序
2020/04/15 Python
pyqt5 实现工具栏文字图片同时显示
2019/06/13 Python
Python数据分析pandas模块用法实例详解
2019/11/20 Python
python GUI库图形界面开发之PyQt5工具栏控件QToolBar的详细使用方法与实例
2020/02/28 Python
python openCV实现摄像头获取人脸图片
2020/08/20 Python
HTML5 Canvas如何实现纹理填充与描边(Fill And Stroke)
2013/07/15 HTML / CSS
HTML5 视频播放(video),JavaScript控制视频的实例代码
2018/10/08 HTML / CSS
全球速卖通巴西站点:Aliexpress巴西
2016/08/24 全球购物
美国零售商店:Blue&Cream
2017/04/07 全球购物
捷克家具销售网站:SCONTO Nábytek
2020/01/02 全球购物
娇韵诗俄罗斯官方网站:Clarins俄罗斯
2020/10/03 全球购物
如何查看在weblogic中已经发布的EJB
2012/06/01 面试题
化学系大学生自荐信范文
2014/03/01 职场文书
个人授权委托书范本
2014/04/03 职场文书
nginx代理实现静态资源访问的示例代码
2022/07/07 Servers