Python基于sklearn库的分类算法简单应用示例


Posted in Python onJuly 09, 2018

本文实例讲述了Python基于sklearn库的分类算法简单应用。分享给大家供大家参考,具体如下:

scikit-learn已经包含在Anaconda中。也可以在官方下载源码包进行安装。本文代码里封装了如下机器学习算法,我们修改数据加载函数,即可一键测试:

# coding=gbk
'''
Created on 2016年6月4日
@author: bryan
'''
import time
from sklearn import metrics
import pickle as pickle
import pandas as pd
# Multinomial Naive Bayes Classifier
def naive_bayes_classifier(train_x, train_y):
  from sklearn.naive_bayes import MultinomialNB
  model = MultinomialNB(alpha=0.01)
  model.fit(train_x, train_y)
  return model
# KNN Classifier
def knn_classifier(train_x, train_y):
  from sklearn.neighbors import KNeighborsClassifier
  model = KNeighborsClassifier()
  model.fit(train_x, train_y)
  return model
# Logistic Regression Classifier
def logistic_regression_classifier(train_x, train_y):
  from sklearn.linear_model import LogisticRegression
  model = LogisticRegression(penalty='l2')
  model.fit(train_x, train_y)
  return model
# Random Forest Classifier
def random_forest_classifier(train_x, train_y):
  from sklearn.ensemble import RandomForestClassifier
  model = RandomForestClassifier(n_estimators=8)
  model.fit(train_x, train_y)
  return model
# Decision Tree Classifier
def decision_tree_classifier(train_x, train_y):
  from sklearn import tree
  model = tree.DecisionTreeClassifier()
  model.fit(train_x, train_y)
  return model
# GBDT(Gradient Boosting Decision Tree) Classifier
def gradient_boosting_classifier(train_x, train_y):
  from sklearn.ensemble import GradientBoostingClassifier
  model = GradientBoostingClassifier(n_estimators=200)
  model.fit(train_x, train_y)
  return model
# SVM Classifier
def svm_classifier(train_x, train_y):
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  model.fit(train_x, train_y)
  return model
# SVM Classifier using cross validation
def svm_cross_validation(train_x, train_y):
  from sklearn.grid_search import GridSearchCV
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1)
  grid_search.fit(train_x, train_y)
  best_parameters = grid_search.best_estimator_.get_params()
  for para, val in list(best_parameters.items()):
    print(para, val)
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True)
  model.fit(train_x, train_y)
  return model
def read_data(data_file):
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
if __name__ == '__main__':
  data_file = "H:\\Research\\data\\trainCG.csv"
  thresh = 0.5
  model_save_file = None
  model_save = {}
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT']
  classifiers = {'NB':naive_bayes_classifier,
         'KNN':knn_classifier,
          'LR':logistic_regression_classifier,
          'RF':random_forest_classifier,
          'DT':decision_tree_classifier,
         'SVM':svm_classifier,
        'SVMCV':svm_cross_validation,
         'GBDT':gradient_boosting_classifier
  }
  print('reading training and testing data...')
  train_x, train_y, test_x, test_y = read_data(data_file)
  for classifier in test_classifiers:
    print('******************* %s ********************' % classifier)
    start_time = time.time()
    model = classifiers[classifier](train_x, train_y)
    print('training took %fs!' % (time.time() - start_time))
    predict = model.predict(test_x)
    if model_save_file != None:
      model_save[classifier] = model
    precision = metrics.precision_score(test_y, predict)
    recall = metrics.recall_score(test_y, predict)
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
    accuracy = metrics.accuracy_score(test_y, predict)
    print('accuracy: %.2f%%' % (100 * accuracy))
  if model_save_file != None:
    pickle.dump(model_save, open(model_save_file, 'wb'))

测试结果如下:

reading training and testing data...
******************* NB ********************
training took 0.004986s!
precision: 78.08%, recall: 71.25%
accuracy: 74.17%
******************* KNN ********************
training took 0.017545s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
******************* LR ********************
training took 0.061161s!
precision: 89.16%, recall: 92.50%
accuracy: 90.07%
******************* RF ********************
training took 0.040111s!
precision: 96.39%, recall: 100.00%
accuracy: 98.01%
******************* DT ********************
training took 0.004513s!
precision: 96.20%, recall: 95.00%
accuracy: 95.36%
******************* SVM ********************
training took 0.242145s!
precision: 97.53%, recall: 98.75%
accuracy: 98.01%
******************* SVMCV ********************
Fitting 3 folds for each of 14 candidates, totalling 42 fits
[Parallel(n_jobs=1)]: Done  42 out of  42 | elapsed:    6.8s finished
probability True
verbose False
coef0 0.0
degree 3
tol 0.001
shrinking True
cache_size 200
gamma 0.001
max_iter -1
C 1000
decision_function_shape None
random_state None
class_weight None
kernel rbf
training took 7.434668s!
precision: 98.75%, recall: 98.75%
accuracy: 98.68%
******************* GBDT ********************
training took 0.521916s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python 正则表达式(转义问题)
Dec 15 Python
Python itertools模块详解
May 09 Python
Python编程之基于概率论的分类方法:朴素贝叶斯
Nov 11 Python
Python缓存技术实现过程详解
Sep 25 Python
导入tensorflow时报错:cannot import name 'abs'的解决
Oct 10 Python
在python中利用try..except来代替if..else的用法
Dec 19 Python
python tkinter之顶层菜单、弹出菜单实例
Mar 04 Python
python 如何利用argparse解析命令行参数
Sep 11 Python
Python识别处理照片中的条形码
Nov 16 Python
Python读取pdf表格写入excel的方法
Jan 22 Python
python基础学习之递归函数知识总结
May 26 Python
opencv深入浅出了解机器学习和深度学习
Mar 17 Python
Python不使用int()函数把字符串转换为数字的方法
Jul 09 #Python
python中ASCII码和字符的转换方法
Jul 09 #Python
python中ASCII码字符与int之间的转换方法
Jul 09 #Python
Python 十六进制整数与ASCii编码字符串相互转换方法
Jul 09 #Python
python 以16进制打印输出的方法
Jul 09 #Python
python爬虫之urllib3的使用示例
Jul 09 #Python
机器学习之KNN算法原理及Python实现方法详解
Jul 09 #Python
You might like
PHP4之真OO
2006/10/09 PHP
漂亮但不安全的CTB
2006/10/09 PHP
浅析PHP的ASCII码转换类
2013/07/05 PHP
在PHP模板引擎smarty生成随机数的方法和math函数详解
2014/04/24 PHP
php中使用Ajax时出现Error(c00ce56e)的详细解决方案
2014/11/03 PHP
Laravel中间件实现原理详解
2016/10/09 PHP
php使用redis的几种常见操作方式和用法示例
2020/02/20 PHP
让插入到 innerHTML 中的 script 跑起来的实现代码
2006/07/01 Javascript
jQuery-onload让第一次页面加载时图片是淡入方式显示
2012/05/23 Javascript
jquery实现弹出层效果实例
2015/05/19 Javascript
jQuery.trim() 函数及trim()用法详解
2015/10/26 Javascript
JavaScript实现打开链接页面的方式汇总
2016/06/02 Javascript
深入浅析JS的数组遍历方法(推荐)
2016/06/15 Javascript
Bootstrap中文本框的宽度变窄并且加入一副验证码图片的实现方法
2016/06/23 Javascript
基于CSS3和jQuery实现跟随鼠标方位的Hover特效
2016/07/25 Javascript
javascript轮播图算法
2016/10/21 Javascript
浅谈Vue.js
2017/03/02 Javascript
Bootstrap免费字体和图标网站(值得收藏)
2017/03/16 Javascript
基于JavaScript实现活动倒计时效果
2017/04/20 Javascript
NodeJs安装npm包一直失败的解决方法
2017/04/28 NodeJs
vue.js 实现点击按钮动态添加li的方法
2018/09/07 Javascript
浅谈JavaScript中this的指向更改
2020/07/28 Javascript
python目录操作之python遍历文件夹后将结果存储为xml
2014/01/27 Python
Python with的用法
2014/08/22 Python
python bottle框架支持jquery ajax的RESTful风格的PUT和DELETE方法
2017/05/24 Python
Python带动态参数功能的sqlite工具类
2018/05/26 Python
python多线程分块读取文件
2019/08/29 Python
Python Numpy,mask图像的生成详解
2020/02/19 Python
利用CSS3实现开门效果实例源码
2016/08/22 HTML / CSS
英国领先的运动物理治疗供应公司:Vivomed
2018/07/14 全球购物
波兰办公用品和学校用品在线商店:Dlabiura24.pl
2020/11/18 全球购物
禁止高声喧哗的标语
2014/06/11 职场文书
党员个人对照检查材料
2014/10/01 职场文书
2014年社区工会工作总结
2014/12/18 职场文书
浅谈css实现背景颜色半透明的两种方法
2021/12/06 HTML / CSS
B站评分公认最好看的动漫,你的名字评分9.9,第六备受喜欢
2022/03/18 日漫