A Simple Example of Applying Classification Algorithms with Python's sklearn Library


Posted in Python on July 09, 2018

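This article walks through a simple application of classification algorithms using Python's sklearn library. It is shared here for reference; the details are as follows: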

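scikit-learn ships with Anaconda. It can also be installed from the source package available on the official site. The code in this article wraps the machine learning algorithms listed below; to test them all in one run, we only need to modify the data loading function: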

# coding=utf-8
'''
Created on June 4, 2016
@author: bryan
'''
import time
from sklearn import metrics
import pickle
import pandas as pd
# Multinomial Naive Bayes Classifier
def naive_bayes_classifier(train_x, train_y):
  from sklearn.naive_bayes import MultinomialNB
  # alpha is the additive smoothing parameter; MultinomialNB requires non-negative feature values
  model = MultinomialNB(alpha=0.01)
  model.fit(train_x, train_y)
  return model
# KNN Classifier
def knn_classifier(train_x, train_y):
  from sklearn.neighbors import KNeighborsClassifier
  model = KNeighborsClassifier()
  model.fit(train_x, train_y)
  return model
# Logistic Regression Classifier
def logistic_regression_classifier(train_x, train_y):
  from sklearn.linear_model import LogisticRegression
  model = LogisticRegression(penalty='l2')
  model.fit(train_x, train_y)
  return model
# Random Forest Classifier
def random_forest_classifier(train_x, train_y):
  from sklearn.ensemble import RandomForestClassifier
  model = RandomForestClassifier(n_estimators=8)
  model.fit(train_x, train_y)
  return model
# Decision Tree Classifier
def decision_tree_classifier(train_x, train_y):
  from sklearn import tree
  model = tree.DecisionTreeClassifier()
  model.fit(train_x, train_y)
  return model
# GBDT(Gradient Boosting Decision Tree) Classifier
def gradient_boosting_classifier(train_x, train_y):
  from sklearn.ensemble import GradientBoostingClassifier
  model = GradientBoostingClassifier(n_estimators=200)
  model.fit(train_x, train_y)
  return model
# SVM Classifier
def svm_classifier(train_x, train_y):
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  model.fit(train_x, train_y)
  return model
# SVM Classifier using cross validation
def svm_cross_validation(train_x, train_y):
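  # GridSearchCV lives in sklearn.model_selection; sklearn.grid_search was removed in scikit-learn 0.20
  from sklearn.model_selection import GridSearchCV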
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1)
  grid_search.fit(train_x, train_y)
  best_parameters = grid_search.best_estimator_.get_params()
  for para, val in list(best_parameters.items()):
    print(para, val)
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True)
  model.fit(train_x, train_y)
  return model
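# Load the CSV and split it into training and test sets
def read_data(data_file):
  data = pd.read_csv(data_file)
  # The first 90% of the rows are used for training, the last 10% for testing (no shuffling)
  split = int(len(data) * 0.9)
  train = data[:split]
  test = data[split:]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y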
if __name__ == '__main__':
  data_file = "H:\\Research\\data\\trainCG.csv"  # CSV with the feature columns plus a 'label' column
  thresh = 0.5  # defined but not used anywhere below
  model_save_file = None  # set to a file path to pickle the trained models
  model_save = {}
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM', 'SVMCV', 'GBDT']
  classifiers = {
    'NB': naive_bayes_classifier,
    'KNN': knn_classifier,
    'LR': logistic_regression_classifier,
    'RF': random_forest_classifier,
    'DT': decision_tree_classifier,
    'SVM': svm_classifier,
    'SVMCV': svm_cross_validation,
    'GBDT': gradient_boosting_classifier,
  }
  print('reading training and testing data...')
  train_x, train_y, test_x, test_y = read_data(data_file)
  for classifier in test_classifiers:
    print('******************* %s ********************' % classifier)
    start_time = time.time()
    model = classifiers[classifier](train_x, train_y)
    print('training took %fs!' % (time.time() - start_time))
    predict = model.predict(test_x)
    if model_save_file is not None:
      model_save[classifier] = model
    # precision_score and recall_score default to average='binary' with 1 as the positive class
    precision = metrics.precision_score(test_y, predict)
    recall = metrics.recall_score(test_y, predict)
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
    accuracy = metrics.accuracy_score(test_y, predict)
    print('accuracy: %.2f%%' % (100 * accuracy))
  if model_save_file is not None:
    # Use a context manager so the file handle is closed after dumping
    with open(model_save_file, 'wb') as f:
      pickle.dump(model_save, f)
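
The read_data function above takes the first 90% of the rows in file order, so a CSV that happens to be sorted by label would yield a skewed test set. As a minimal sketch, assuming the same layout (feature columns plus a 'label' column), the split could instead be shuffled and stratified with scikit-learn's train_test_split. The helper name split_data, its seed parameter, and the toy DataFrame are hypothetical and serve only as illustration; they are not part of the original script.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def split_data(data, test_size=0.1, seed=0):
    """Shuffle and split a DataFrame that has a 'label' column.

    Hypothetical helper for illustration; not part of the original script.
    """
    x = data.drop('label', axis=1)
    y = data['label']
    # stratify=y keeps the class proportions the same in both parts
    train_x, test_x, train_y, test_y = train_test_split(
        x, y, test_size=test_size, stratify=y, random_state=seed)
    return train_x, train_y, test_x, test_y


# Toy data for illustration only: 20 rows, sorted as ten 0 labels then ten 1 labels
toy = pd.DataFrame({'f1': range(20), 'f2': range(20, 40),
                    'label': [0] * 10 + [1] * 10})
train_x, train_y, test_x, test_y = split_data(toy, test_size=0.2)
print(len(train_x), len(test_x))
print(sorted(test_y.tolist()))
```

Fixing random_state keeps the split reproducible between runs, which matters when comparing the classifiers against each other.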

The test results are as follows:

reading training and testing data...
******************* NB ********************
training took 0.004986s!
precision: 78.08%, recall: 71.25%
accuracy: 74.17%
******************* KNN ********************
training took 0.017545s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
******************* LR ********************
training took 0.061161s!
precision: 89.16%, recall: 92.50%
accuracy: 90.07%
******************* RF ********************
training took 0.040111s!
precision: 96.39%, recall: 100.00%
accuracy: 98.01%
******************* DT ********************
training took 0.004513s!
precision: 96.20%, recall: 95.00%
accuracy: 95.36%
******************* SVM ********************
training took 0.242145s!
precision: 97.53%, recall: 98.75%
accuracy: 98.01%
******************* SVMCV ********************
Fitting 3 folds for each of 14 candidates, totalling 42 fits
[Parallel(n_jobs=1)]: Done  42 out of  42 | elapsed:    6.8s finished
probability True
verbose False
coef0 0.0
degree 3
tol 0.001
shrinking True
cache_size 200
gamma 0.001
max_iter -1
C 1000
decision_function_shape None
random_state None
class_weight None
kernel rbf
training took 7.434668s!
precision: 98.75%, recall: 98.75%
accuracy: 98.68%
******************* GBDT ********************
training took 0.521916s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
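
For reading the numbers above: precision is TP / (TP + FP), recall is TP / (TP + FN), and accuracy is the share of predictions that match the true label. The sketch below computes the three values from two label lists in plain Python, assuming binary labels with 1 as the positive class. The function name binary_scores and the toy label lists are hypothetical; they are not taken from the data used in this article.

```python
def binary_scores(y_true, y_pred, positive=1):
    """Return (precision, recall, accuracy) for binary labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    correct = sum(1 for t, p in pairs if t == p)
    # Fall back to 0.0 when a denominator is zero (a choice made for this sketch)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    accuracy = correct / len(pairs) if pairs else 0.0
    return precision, recall, accuracy


# Toy labels for illustration only
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 1, 0, 0]
precision, recall, accuracy = binary_scores(y_true, y_pred)
print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
print('accuracy: %.2f%%' % (100 * accuracy))
```

With these toy lists there are 3 true positives, 2 false positives, and 1 false negative, so a recall higher than precision means the classifier misses few positives but raises some false alarms.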

I hope this article is helpful for your Python programming.
