python实现AdaBoost算法的示例


Posted in Python onOctober 03, 2020

代码

'''
数据集:Mnist
训练集数量:60000(实际使用:10000)
测试集数量:10000(实际使用:1000)
层数:40
------------------------------
运行结果:
  正确率:97%
  运行时长:65m
'''

import time
import numpy as np


def loadData(fileName):
  '''
  加载文件
  :param fileName:要加载的文件路径
  :return: 数据集和标签集
  '''
  # 存放数据及标记
  dataArr = []
  labelArr = []
  # 读取文件
  fr = open(fileName)
  # 遍历文件中的每一行
  for line in fr.readlines():
    # 获取当前行,并按“,”切割成字段放入列表中
    # strip:去掉每行字符串首尾指定的字符(默认空格或换行符)
    # split:按照指定的字符将字符串切割成每个字段,返回列表形式
    curLine = line.strip().split(',')
    # 将每行中除标记外的数据放入数据集中(curLine[0]为标记信息)
    # 在放入的同时将原先字符串形式的数据转换为整型
    # 此外将数据进行了二值化处理,大于128的转换成1,小于的转换成0,方便后续计算
    dataArr.append([int(int(num) > 128) for num in curLine[1:]])
    # 将标记信息放入标记集中
    # 放入的同时将标记转换为整型

    # 转换成二分类任务
    # 标签0设置为1,反之为-1
    if int(curLine[0]) == 0:
      labelArr.append(1)
    else:
      labelArr.append(-1)
  # 返回数据集和标记
  return dataArr, labelArr


def calc_e_Gx(trainDataArr, trainLabelArr, n, div, rule, D):
  '''
  计算分类错误率
  :param trainDataArr:训练数据集数字
  :param trainLabelArr: 训练标签集数组
  :param n: 要操作的特征
  :param div:划分点
  :param rule:正反例标签
  :param D:权值分布D
  :return:预测结果, 分类误差率
  '''
  # 初始化分类误差率为0
  e = 0
  # 将训练数据矩阵中特征为n的那一列单独剥出来做成数组。因为其他元素我们并不需要,
  # 直接对庞大的训练集进行操作的话会很慢
  x = trainDataArr[:, n]
  # 同样将标签也转换成数组格式,x和y的转换只是单纯为了提高运行速度
  # 测试过相对直接操作而言性能提升很大
  y = trainLabelArr
  predict = []

  # 依据小于和大于的标签依据实际情况会不同,在这里直接进行设置
  if rule == 'LisOne':
    L = 1
    H = -1
  else:
    L = -1
    H = 1

  # 遍历所有样本的特征m
  for i in range(trainDataArr.shape[0]):
    if x[i] < div:
      # 如果小于划分点,则预测为L
      # 如果设置小于div为1,那么L就是1,
      # 如果设置小于div为-1,L就是-1
      predict.append(L)
      # 如果预测错误,分类错误率要加上该分错的样本的权值(8.1式)
      if y[i] != L:
        e += D[i]
    elif x[i] >= div:
      # 与上面思想一样
      predict.append(H)
      if y[i] != H:
        e += D[i]
  # 返回预测结果和分类错误率e
  # 预测结果其实是为了后面做准备的,在算法8.1第四步式8.4中exp内部有个Gx,要用在那个地方
  # 以此来更新新的D
  return np.array(predict), e


def createSigleBoostingTree(trainDataArr, trainLabelArr, D):
  '''
  创建单层提升树
  :param trainDataArr:训练数据集数组
  :param trainLabelArr: 训练标签集数组
  :param D: 算法8.1中的D
  :return: 创建的单层提升树
  '''

  # 获得样本数目及特征数量
  m, n = np.shape(trainDataArr)
  # 单层树的字典,用于存放当前层提升树的参数
  # 也可以认为该字典代表了一层提升树
  sigleBoostTree = {}
  # 初始化分类误差率,分类误差率在算法8.1步骤(2)(b)有提到
  # 误差率最高也只能100%,因此初始化为1
  sigleBoostTree['e'] = 1

  # 对每一个特征进行遍历,寻找用于划分的最合适的特征
  for i in range(n):
    # 因为特征已经经过二值化,只能为0和1,因此分切分时分为-0.5, 0.5, 1.5三挡进行切割
    for div in [-0.5, 0.5, 1.5]:
      # 在单个特征内对正反例进行划分时,有两种情况:
      # 可能是小于某值的为1,大于某值得为-1,也可能小于某值得是-1,反之为1
      # 因此在寻找最佳提升树的同时对于两种情况也需要遍历运行
      # LisOne:Low is one:小于某值得是1
      # HisOne:High is one:大于某值得是1
      for rule in ['LisOne', 'HisOne']:
        # 按照第i个特征,以值div进行切割,进行当前设置得到的预测和分类错误率
        Gx, e = calc_e_Gx(trainDataArr, trainLabelArr, i, div, rule, D)
        # 如果分类错误率e小于当前最小的e,那么将它作为最小的分类错误率保存
        if e < sigleBoostTree['e']:
          sigleBoostTree['e'] = e
          # 同时也需要存储最优划分点、划分规则、预测结果、特征索引
          # 以便进行D更新和后续预测使用
          sigleBoostTree['div'] = div
          sigleBoostTree['rule'] = rule
          sigleBoostTree['Gx'] = Gx
          sigleBoostTree['feature'] = i
  # 返回单层的提升树
  return sigleBoostTree


def createBosstingTree(trainDataList, trainLabelList, treeNum=50):
  '''
  创建提升树
  创建算法依据“8.1.2 AdaBoost算法” 算法8.1
  :param trainDataList:训练数据集
  :param trainLabelList: 训练测试集
  :param treeNum: 树的层数
  :return: 提升树
  '''
  # 将数据和标签转化为数组形式
  trainDataArr = np.array(trainDataList)
  trainLabelArr = np.array(trainLabelList)
  # 没增加一层数后,当前最终预测结果列表
  finallpredict = [0] * len(trainLabelArr)
  # 获得训练集数量以及特征个数
  m, n = np.shape(trainDataArr)

  # 依据算法8.1步骤(1)初始化D为1/N
  D = [1 / m] * m
  # 初始化提升树列表,每个位置为一层
  tree = []
  # 循环创建提升树
  for i in range(treeNum):
    # 得到当前层的提升树
    curTree = createSigleBoostingTree(trainDataArr, trainLabelArr, D)
    # 根据式8.2计算当前层的alpha
    alpha = 1 / 2 * np.log((1 - curTree['e']) / curTree['e'])
    # 获得当前层的预测结果,用于下一步更新D
    Gx = curTree['Gx']
    # 依据式8.4更新D
    # 考虑到该式每次只更新D中的一个w,要循环进行更新知道所有w更新结束会很复杂(其实
    # 不是时间上的复杂,只是让人感觉每次单独更新一个很累),所以该式以向量相乘的形式,
    # 一个式子将所有w全部更新完。
    # 该式需要线性代数基础,如果不太熟练建议补充相关知识,当然了,单独更新w也一点问题
    # 没有
    # np.multiply(trainLabelArr, Gx):exp中的y*Gm(x),结果是一个行向量,内部为yi*Gm(xi)
    # np.exp(-1 * alpha * np.multiply(trainLabelArr, Gx)):上面求出来的行向量内部全体
    # 成员再乘以-αm,然后取对数,和书上式子一样,只不过书上式子内是一个数,这里是一个向量
    # D是一个行向量,取代了式中的wmi,然后D求和为Zm
    # 书中的式子最后得出来一个数w,所有数w组合形成新的D
    # 这里是直接得到一个向量,向量内元素是所有的w
    # 本质上结果是相同的
    D = np.multiply(D, np.exp(-1 * alpha * np.multiply(trainLabelArr, Gx))) / sum(D)
    # 在当前层参数中增加alpha参数,预测的时候需要用到
    curTree['alpha'] = alpha
    # 将当前层添加到提升树索引中。
    tree.append(curTree)

    # -----以下代码用来辅助,可以去掉---------------
    # 根据8.6式将结果加上当前层乘以α,得到目前的最终输出预测
    finallpredict += alpha * Gx
    # 计算当前最终预测输出与实际标签之间的误差
    error = sum([1 for i in range(len(trainDataList)) if np.sign(finallpredict[i]) != trainLabelArr[i]])
    # 计算当前最终误差率
    finallError = error / len(trainDataList)
    # 如果误差为0,提前退出即可,因为没有必要再计算算了
    if finallError == 0:
      return tree
    # 打印一些信息
    print('iter:%d:%d, sigle error:%.4f, finall error:%.4f' % (i, treeNum, curTree['e'], finallError))
  # 返回整个提升树
  return tree


def predict(x, div, rule, feature):
  '''
  输出单独层预测结果
  :param x: 预测样本
  :param div: 划分点
  :param rule: 划分规则
  :param feature: 进行操作的特征
  :return:
  '''
  # 依据划分规则定义小于及大于划分点的标签
  if rule == 'LisOne':
    L = 1
    H = -1
  else:
    L = -1
    H = 1

  # 判断预测结果
  if x[feature] < div:
    return L
  else:
    return H


def test(testDataList, testLabelList, tree):
  '''
  测试
  :param testDataList:测试数据集
  :param testLabelList: 测试标签集
  :param tree: 提升树
  :return: 准确率
  '''
  # 错误率计数值
  errorCnt = 0
  # 遍历每一个测试样本
  for i in range(len(testDataList)):
    # 预测结果值,初始为0
    result = 0
    # 依据算法8.1式8.6
    # 预测式子是一个求和式,对于每一层的结果都要进行一次累加
    # 遍历每层的树
    for curTree in tree:
      # 获取该层参数
      div = curTree['div']
      rule = curTree['rule']
      feature = curTree['feature']
      alpha = curTree['alpha']
      # 将当前层结果加入预测中
      result += alpha * predict(testDataList[i], div, rule, feature)
    # 预测结果取sign值,如果大于0 sign为1,反之为0
    if np.sign(result) != testLabelList[i]: 
      errorCnt += 1
  # 返回准确率
  return 1 - errorCnt / len(testDataList)


if __name__ == '__main__':
  # 开始时间
  start = time.time()

  # 获取训练集
  print('start read transSet')
  trainDataList, trainLabelList = loadData('../Mnist/mnist_train.csv')

  # 获取测试集
  print('start read testSet')
  testDataList, testLabelList = loadData('../Mnist/mnist_test.csv')

  # 创建提升树
  print('start init train')
  tree = createBosstingTree(trainDataList[:10000], trainLabelList[:10000], 40)

  # 测试
  print('start to test')
  accuracy = test(testDataList[:1000], testLabelList[:1000], tree)
  print('the accuracy is:%d' % (accuracy * 100), '%')

  # 结束时间
  end = time.time()
  print('time span:', end - start)

程序运行结果

start read transSet
start read testSet
start init train
iter:0:40, sigle error:0.0804, finall error:0.0804
iter:1:40, sigle error:0.1448, finall error:0.0804
iter:2:40, sigle error:0.1362, finall error:0.0585
iter:3:40, sigle error:0.1864, finall error:0.0667
iter:4:40, sigle error:0.2249, finall error:0.0474
iter:5:40, sigle error:0.2634, finall error:0.0437
iter:6:40, sigle error:0.2626, finall error:0.0377
iter:7:40, sigle error:0.2935, finall error:0.0361
iter:8:40, sigle error:0.3230, finall error:0.0333
iter:9:40, sigle error:0.3034, finall error:0.0361
iter:10:40, sigle error:0.3375, finall error:0.0325
iter:11:40, sigle error:0.3364, finall error:0.0340
iter:12:40, sigle error:0.3473, finall error:0.0309
iter:13:40, sigle error:0.3006, finall error:0.0294
iter:14:40, sigle error:0.3267, finall error:0.0275
iter:15:40, sigle error:0.3584, finall error:0.0288
iter:16:40, sigle error:0.3492, finall error:0.0257
iter:17:40, sigle error:0.3506, finall error:0.0256
iter:18:40, sigle error:0.3665, finall error:0.0240
iter:19:40, sigle error:0.3769, finall error:0.0251
iter:20:40, sigle error:0.3828, finall error:0.0213
iter:21:40, sigle error:0.3733, finall error:0.0229
iter:22:40, sigle error:0.3785, finall error:0.0218
iter:23:40, sigle error:0.3867, finall error:0.0219
iter:24:40, sigle error:0.3850, finall error:0.0208
iter:25:40, sigle error:0.3823, finall error:0.0201
iter:26:40, sigle error:0.3825, finall error:0.0204
iter:27:40, sigle error:0.3874, finall error:0.0188
iter:28:40, sigle error:0.3952, finall error:0.0186
iter:29:40, sigle error:0.4018, finall error:0.0193
iter:30:40, sigle error:0.3889, finall error:0.0177
iter:31:40, sigle error:0.3939, finall error:0.0183
iter:32:40, sigle error:0.3838, finall error:0.0182
iter:33:40, sigle error:0.4021, finall error:0.0171
iter:34:40, sigle error:0.4119, finall error:0.0164
iter:35:40, sigle error:0.4093, finall error:0.0164
iter:36:40, sigle error:0.4135, finall error:0.0167
iter:37:40, sigle error:0.4099, finall error:0.0171
iter:38:40, sigle error:0.3871, finall error:0.0163
iter:39:40, sigle error:0.4085, finall error:0.0154
start to test
the accuracy is:97 %
time span: 3777.730945825577

以上就是python实现AdaBoost算法的示例的详细内容,更多关于python实现AdaBoost算法的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python实现带验证码网站的自动登陆实现代码
Jan 12 Python
一个基于flask的web应用诞生 用户注册功能开发(5)
Apr 11 Python
Python通过future处理并发问题
Oct 17 Python
Python批量更改文件名的实现方法
Oct 29 Python
python利用sklearn包编写决策树源代码
Dec 21 Python
set在python里的含义和用法
Jun 24 Python
django框架实现模板中获取request 的各种信息示例
Jul 01 Python
scikit-learn线性回归,多元回归,多项式回归的实现
Aug 29 Python
python加载自定义词典实例
Dec 06 Python
python解释器pycharm安装及环境变量配置教程图文详解
Feb 26 Python
Python爬虫爬取、解析数据操作示例
Mar 27 Python
python析构函数用法及注意事项
Jun 22 Python
Django创建一个后台的基本步骤记录
Oct 02 #Python
Python中qutip用法示例详解
Oct 02 #Python
如何利用Python给自己的头像加一个小国旗(小月饼)
Oct 02 #Python
Python通过fnmatch模块实现文件名匹配
Sep 30 #Python
Python tempfile模块生成临时文件和临时目录
Sep 30 #Python
Python实现定时监测网站运行状态的示例代码
Sep 30 #Python
python如何实现word批量转HTML
Sep 30 #Python
You might like
让PHP支持页面回退的两种方法
2008/01/10 PHP
PHP中file_exists使用中遇到的问题小结
2016/04/05 PHP
Thinkphp5行为使用方法汇总
2017/12/21 PHP
php删除二维数组中的重复值方法
2018/03/12 PHP
Thinkphp5.0 框架Model模型简单用法分析
2019/10/11 PHP
JavaScript 闭包在封装函数时的简单分析
2009/11/28 Javascript
jQuery 中使用JSON的实现代码
2011/12/01 Javascript
js 金额文本框实现代码
2012/02/14 Javascript
JavaScript解析json格式数据简单示例
2014/12/09 Javascript
封装属于自己的JS组件
2016/01/27 Javascript
javascript拖拽应用实例(二)
2016/03/25 Javascript
JavaScript中动态向表格添加数据
2017/01/24 Javascript
Bootstrap提示框效果的实例代码
2017/07/12 Javascript
微信小程序中button组件的边框设置的实例详解
2017/09/27 Javascript
JS获取input[file]的值并显示在页面的实现方法
2018/03/09 Javascript
基于Vue实现可以拖拽的树形表格实例详解
2018/10/18 Javascript
react-router 路由切换动画的实现示例
2018/12/03 Javascript
详解angularjs4部署文件过大解决过程
2018/12/05 Javascript
微信小程序自定义组件实现环形进度条
2020/11/17 Javascript
Vue管理系统前端之组件拆分封装详解
2020/08/23 Javascript
深入解答关于Python的11道基本面试题
2017/04/01 Python
Python实现OpenCV的安装与使用示例
2018/03/30 Python
Python实现操纵控制windows注册表的方法分析
2019/05/24 Python
django echarts饼图数据动态加载的实例
2019/08/12 Python
OpenCV 表盘指针自动读数的示例代码
2020/04/10 Python
python matplotlib.pyplot.plot()参数用法
2020/04/14 Python
Python并发请求下限制QPS(每秒查询率)的实现代码
2020/06/05 Python
python中strip(),lstrip(),rstrip()函数的使用讲解
2020/11/17 Python
HTML5 Canvas中绘制椭圆的4种方法
2015/04/24 HTML / CSS
乌克兰第一的珠宝网上商店:Gold.ua
2019/11/29 全球购物
临床护理求职信
2014/04/26 职场文书
读书活动总结
2014/04/28 职场文书
优秀语文教师事迹
2014/05/18 职场文书
理想演讲稿范文
2014/05/21 职场文书
食品安全责任书范本
2015/05/09 职场文书
幼儿园开学家长寄语(2016秋季)
2015/12/03 职场文书