A Simple Example of Classification Algorithms in Python Using the sklearn Library


Posted in Python on July 09, 2018

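This article walks through a simple example of using the classification algorithms in Python's sklearn library. It is shared here for reference. Details follow: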

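scikit-learn ships with Anaconda. You can also download the source package from the official site and install it yourself. The code below wraps the following machine learning algorithms. To test all of them on your own data in one run, you only need to modify the data-loading function: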

# coding=utf-8
'''
Created on June 4, 2016
@author: bryan
'''
import time
from sklearn import metrics
import pickle
import pandas as pd
# Multinomial Naive Bayes Classifier (features must be non-negative, e.g. counts)
def naive_bayes_classifier(train_x, train_y):
  from sklearn.naive_bayes import MultinomialNB
  model = MultinomialNB(alpha=0.01)
  model.fit(train_x, train_y)
  return model
# KNN Classifier
def knn_classifier(train_x, train_y):
  from sklearn.neighbors import KNeighborsClassifier
  model = KNeighborsClassifier()
  model.fit(train_x, train_y)
  return model
# Logistic Regression Classifier
def logistic_regression_classifier(train_x, train_y):
  from sklearn.linear_model import LogisticRegression
  model = LogisticRegression(penalty='l2')
  model.fit(train_x, train_y)
  return model
# Random Forest Classifier
def random_forest_classifier(train_x, train_y):
  from sklearn.ensemble import RandomForestClassifier
  model = RandomForestClassifier(n_estimators=8)
  model.fit(train_x, train_y)
  return model
# Decision Tree Classifier
def decision_tree_classifier(train_x, train_y):
  from sklearn import tree
  model = tree.DecisionTreeClassifier()
  model.fit(train_x, train_y)
  return model
# GBDT(Gradient Boosting Decision Tree) Classifier
def gradient_boosting_classifier(train_x, train_y):
  from sklearn.ensemble import GradientBoostingClassifier
  model = GradientBoostingClassifier(n_estimators=200)
  model.fit(train_x, train_y)
  return model
# SVM Classifier
def svm_classifier(train_x, train_y):
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  model.fit(train_x, train_y)
  return model
# SVM Classifier using cross validation
def svm_cross_validation(train_x, train_y):
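  # sklearn.grid_search was removed in scikit-learn 0.20; GridSearchCV now lives in model_selection
  from sklearn.model_selection import GridSearchCV
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}
  # cv=3 matches the 3-fold search in the log below (newer releases default to 5 folds)
  grid_search = GridSearchCV(model, param_grid, cv=3, n_jobs=1, verbose=1)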
  grid_search.fit(train_x, train_y)
  best_parameters = grid_search.best_estimator_.get_params()
  for para, val in list(best_parameters.items()):
    print(para, val)
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True)
  model.fit(train_x, train_y)
  return model
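# Load a CSV whose 'label' column is the target: first 90% of rows -> train, last 10% -> test (no shuffling)
def read_data(data_file):
  data = pd.read_csv(data_file)
  split = int(len(data) * 0.9)
  train = data.iloc[:split]
  test = data.iloc[split:]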
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
if __name__ == '__main__':
  data_file = "H:\\Research\\data\\trainCG.csv"
  thresh = 0.5
  model_save_file = None
  model_save = {}
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT']
  classifiers = {'NB': naive_bayes_classifier,
                 'KNN': knn_classifier,
                 'LR': logistic_regression_classifier,
                 'RF': random_forest_classifier,
                 'DT': decision_tree_classifier,
                 'SVM': svm_classifier,
                 'SVMCV': svm_cross_validation,
                 'GBDT': gradient_boosting_classifier
  }
  print('reading training and testing data...')
  train_x, train_y, test_x, test_y = read_data(data_file)
  for classifier in test_classifiers:
    print('******************* %s ********************' % classifier)
    start_time = time.time()
    model = classifiers[classifier](train_x, train_y)
    print('training took %fs!' % (time.time() - start_time))
    predict = model.predict(test_x)
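    if model_save_file is not None:
      model_save[classifier] = model
    # precision_score/recall_score default to average='binary': labels must be binary, positive class is 1
    precision = metrics.precision_score(test_y, predict)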
    recall = metrics.recall_score(test_y, predict)
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
    accuracy = metrics.accuracy_score(test_y, predict)
    print('accuracy: %.2f%%' % (100 * accuracy))
  if model_save_file is not None:
    # Save all trained models in one pickle; the with-block closes the file
    with open(model_save_file, 'wb') as f:
      pickle.dump(model_save, f)
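The script reads a private file (H:\\Research\\data\\trainCG.csv) that is not provided. read_data only assumes numeric feature columns plus a column named label. It also splits by row order, so a file sorted by label would put a single class in the test set. The following is a minimal sketch under those assumptions. It writes a synthetic CSV in that layout and loads it with a stratified random 90/10 split. The names make_demo_csv, read_data_shuffled and demo_train.csv are hypothetical and do not come from the original article.

```python
# Hypothetical helpers: build a synthetic CSV in the layout read_data expects,
# and load it with a shuffled, stratified 90/10 split.
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


def make_demo_csv(path, n_samples=1500, n_features=10, random_state=0):
    """Write a synthetic binary-classification dataset with a 'label' column."""
    X, y = make_classification(n_samples=n_samples, n_features=n_features,
                               n_informative=5, random_state=random_state)
    # MultinomialNB rejects negative inputs, so shift every feature to be >= 0
    X = X - X.min(axis=0)
    df = pd.DataFrame(X, columns=['f%d' % i for i in range(n_features)])
    df['label'] = y
    df.to_csv(path, index=False)
    return df


def read_data_shuffled(data_file, test_size=0.1, random_state=0):
    """Drop-in alternative to read_data: random split that keeps class ratios."""
    data = pd.read_csv(data_file)
    x = data.drop('label', axis=1)
    y = data['label']
    train_x, test_x, train_y, test_y = train_test_split(
        x, y, test_size=test_size, stratify=y, random_state=random_state)
    # Same return order as read_data in the script above
    return train_x, train_y, test_x, test_y


if __name__ == '__main__':
    make_demo_csv('demo_train.csv')
    train_x, train_y, test_x, test_y = read_data_shuffled('demo_train.csv')
    print(train_x.shape, test_x.shape)  # prints (1350, 10) (150, 10)
```

To use it, point data_file at the generated CSV and call read_data_shuffled in place of read_data. The rest of the main block is unchanged because the return order is the same.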

The test results are as follows:

reading training and testing data...
******************* NB ********************
training took 0.004986s!
precision: 78.08%, recall: 71.25%
accuracy: 74.17%
******************* KNN ********************
training took 0.017545s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
******************* LR ********************
training took 0.061161s!
precision: 89.16%, recall: 92.50%
accuracy: 90.07%
******************* RF ********************
training took 0.040111s!
precision: 96.39%, recall: 100.00%
accuracy: 98.01%
******************* DT ********************
training took 0.004513s!
precision: 96.20%, recall: 95.00%
accuracy: 95.36%
******************* SVM ********************
training took 0.242145s!
precision: 97.53%, recall: 98.75%
accuracy: 98.01%
******************* SVMCV ********************
Fitting 3 folds for each of 14 candidates, totalling 42 fits
[Parallel(n_jobs=1)]: Done  42 out of  42 | elapsed:    6.8s finished
probability True
verbose False
coef0 0.0
degree 3
tol 0.001
shrinking True
cache_size 200
gamma 0.001
max_iter -1
C 1000
decision_function_shape None
random_state None
class_weight None
kernel rbf
training took 7.434668s!
precision: 98.75%, recall: 98.75%
accuracy: 98.68%
******************* GBDT ********************
training took 0.521916s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
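In the SVMCV log above, the grid search fits 14 candidates (7 values of C times 2 of gamma) on 3 folds, which gives 42 fits, and selects C=1000 and gamma=0.001. svm_cross_validation then trains a second SVC with those values. However, GridSearchCV already refits the best candidate on the whole training set by default (refit=True) and exposes it as best_estimator_. The second fit is therefore probably redundant. Here is a sketch of that variant. The function name svm_cross_validation_refit is hypothetical. The returned model should be equivalent to the original's retrained SVC, apart from the randomness in the internal probability calibration that probability=True enables.

```python
# Sketch: return the estimator GridSearchCV already refit instead of training a second SVC
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC


def svm_cross_validation_refit(train_x, train_y, cv=3):
    param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000],
                  'gamma': [0.001, 0.0001]}
    grid_search = GridSearchCV(SVC(kernel='rbf', probability=True),
                               param_grid, cv=cv, n_jobs=1)
    # With refit=True (the default), the best candidate is retrained on all of train_x/train_y
    grid_search.fit(train_x, train_y)
    # best_params_ holds only the tuned keys, unlike get_params() on the estimator
    print(grid_search.best_params_, 'mean CV accuracy: %.4f' % grid_search.best_score_)
    return grid_search.best_estimator_


if __name__ == '__main__':
    X, y = make_classification(n_samples=300, random_state=0)
    model = svm_cross_validation_refit(X, y)
    print(model.predict(X[:5]))
```

The signature matches the other classifier functions, so it could be registered under a new key in the classifiers dict.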

I hope this article is helpful for your Python programming.
