A Simple Example of Classification Algorithms in Python with the sklearn Library


Posted in Python on July 09, 2018

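This article walks through a simple application of classification algorithms in Python using the sklearn library. It is shared here for reference; the details follow:

scikit-learn is already included in Anaconda. You can also download the source package from the official site and install it. The code in this article wraps the machine learning algorithms listed below; modify the data-loading function (read_data) to point at your own data, and you can test all of them in a single run: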

# coding=gbk
'''
Created on June 4, 2016
@author: bryan
'''
import time
from sklearn import metrics
import pickle
import pandas as pd
# Multinomial Naive Bayes Classifier
def naive_bayes_classifier(train_x, train_y):
  from sklearn.naive_bayes import MultinomialNB
  model = MultinomialNB(alpha=0.01)
  model.fit(train_x, train_y)
  return model
# KNN Classifier
def knn_classifier(train_x, train_y):
  from sklearn.neighbors import KNeighborsClassifier
  model = KNeighborsClassifier()
  model.fit(train_x, train_y)
  return model
# Logistic Regression Classifier
def logistic_regression_classifier(train_x, train_y):
  from sklearn.linear_model import LogisticRegression
  model = LogisticRegression(penalty='l2')
  model.fit(train_x, train_y)
  return model
# Random Forest Classifier
def random_forest_classifier(train_x, train_y):
  from sklearn.ensemble import RandomForestClassifier
  model = RandomForestClassifier(n_estimators=8)
  model.fit(train_x, train_y)
  return model
# Decision Tree Classifier
def decision_tree_classifier(train_x, train_y):
  from sklearn import tree
  model = tree.DecisionTreeClassifier()
  model.fit(train_x, train_y)
  return model
# GBDT(Gradient Boosting Decision Tree) Classifier
def gradient_boosting_classifier(train_x, train_y):
  from sklearn.ensemble import GradientBoostingClassifier
  model = GradientBoostingClassifier(n_estimators=200)
  model.fit(train_x, train_y)
  return model
# SVM Classifier
def svm_classifier(train_x, train_y):
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  model.fit(train_x, train_y)
  return model
# SVM Classifier using cross validation
def svm_cross_validation(train_x, train_y):
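  # sklearn.grid_search was removed in scikit-learn 0.20; GridSearchCV now lives in model_selection
  from sklearn.model_selection import GridSearchCV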
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1)
  grid_search.fit(train_x, train_y)
  best_parameters = grid_search.best_estimator_.get_params()
  for para, val in list(best_parameters.items()):
    print(para, val)
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True)
  model.fit(train_x, train_y)
  return model
def read_data(data_file):
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
if __name__ == '__main__':
  data_file = "H:\\Research\\data\\trainCG.csv"
  thresh = 0.5
  model_save_file = None
  model_save = {}
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT']
  classifiers = {'NB':naive_bayes_classifier,
         'KNN':knn_classifier,
          'LR':logistic_regression_classifier,
          'RF':random_forest_classifier,
          'DT':decision_tree_classifier,
         'SVM':svm_classifier,
        'SVMCV':svm_cross_validation,
         'GBDT':gradient_boosting_classifier
  }
  print('reading training and testing data...')
  train_x, train_y, test_x, test_y = read_data(data_file)
  for classifier in test_classifiers:
    print('******************* %s ********************' % classifier)
    start_time = time.time()
    model = classifiers[classifier](train_x, train_y)
    print('training took %fs!' % (time.time() - start_time))
    predict = model.predict(test_x)
    if model_save_file is not None:
      model_save[classifier] = model
    precision = metrics.precision_score(test_y, predict)
    recall = metrics.recall_score(test_y, predict)
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
    accuracy = metrics.accuracy_score(test_y, predict)
    print('accuracy: %.2f%%' % (100 * accuracy))
  if model_save_file is not None:
    # use a context manager so the file is closed after the dump
    with open(model_save_file, 'wb') as f:
      pickle.dump(model_save, f)

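Since the CSV at the path above is not included with the article, here is a minimal sketch of a stand-in loader for trying the script. It assumes a synthetic dataset from make_classification; the helper names make_demo_frame and split_frame are hypothetical and not part of the original code. It also swaps the original "first 90% / last 10%" slice for a shuffled, stratified train_test_split, which is a different technique and mainly matters when the rows of a file are sorted by label.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


def make_demo_frame(n_samples=200, n_features=5, seed=0):
    """Build a small synthetic DataFrame with a binary 'label' column (assumed data)."""
    x, y = make_classification(n_samples=n_samples, n_features=n_features,
                               n_informative=3, n_redundant=0,
                               random_state=seed)
    frame = pd.DataFrame(x, columns=['f%d' % i for i in range(n_features)])
    frame['label'] = y
    return frame


def split_frame(data, test_size=0.1, seed=0):
    """Shuffled, stratified split returning the same 4-tuple order as read_data."""
    x = data.drop('label', axis=1)
    y = data['label']
    train_x, test_x, train_y, test_y = train_test_split(
        x, y, test_size=test_size, stratify=y, random_state=seed)
    return train_x, train_y, test_x, test_y


demo = make_demo_frame()
train_x, train_y, test_x, test_y = split_frame(demo)
print(train_x.shape, test_x.shape)
```

Note that these synthetic features contain negative values, which MultinomialNB rejects, so with this stand-in data you would either drop 'NB' from test_classifiers or rescale the features to be non-negative first.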
The test results are as follows:

reading training and testing data...
******************* NB ********************
training took 0.004986s!
precision: 78.08%, recall: 71.25%
accuracy: 74.17%
******************* KNN ********************
training took 0.017545s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
******************* LR ********************
training took 0.061161s!
precision: 89.16%, recall: 92.50%
accuracy: 90.07%
******************* RF ********************
training took 0.040111s!
precision: 96.39%, recall: 100.00%
accuracy: 98.01%
******************* DT ********************
training took 0.004513s!
precision: 96.20%, recall: 95.00%
accuracy: 95.36%
******************* SVM ********************
training took 0.242145s!
precision: 97.53%, recall: 98.75%
accuracy: 98.01%
******************* SVMCV ********************
Fitting 3 folds for each of 14 candidates, totalling 42 fits
[Parallel(n_jobs=1)]: Done  42 out of  42 | elapsed:    6.8s finished
probability True
verbose False
coef0 0.0
degree 3
tol 0.001
shrinking True
cache_size 200
gamma 0.001
max_iter -1
C 1000
decision_function_shape None
random_state None
class_weight None
kernel rbf
training took 7.434668s!
precision: 98.75%, recall: 98.75%
accuracy: 98.68%
******************* GBDT ********************
training took 0.521916s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
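
To make the three numbers printed for each classifier concrete, the sketch below recomputes precision, recall and accuracy by hand from confusion counts and compares them with the sklearn.metrics calls used in the script. The ten labels are made up for illustration and are not taken from the article's dataset.

```python
from sklearn import metrics

# Made-up labels for illustration: 4 positives, 6 negatives
test_y = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
predict = [1, 1, 1, 0, 1, 1, 0, 0, 0, 0]

pairs = list(zip(test_y, predict))
tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # true positives
fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # false positives
fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # false negatives
tn = sum(1 for t, p in pairs if t == 0 and p == 0)  # true negatives

precision = tp / (tp + fp)            # of predicted positives, how many are right
recall = tp / (tp + fn)               # of actual positives, how many were found
accuracy = (tp + tn) / len(test_y)    # overall fraction of correct predictions

print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
print('accuracy: %.2f%%' % (100 * accuracy))
# prints:
# precision: 60.00%, recall: 75.00%
# accuracy: 70.00%

# The hand-computed values agree with sklearn's implementations
assert abs(precision - metrics.precision_score(test_y, predict)) < 1e-12
assert abs(recall - metrics.recall_score(test_y, predict)) < 1e-12
assert abs(accuracy - metrics.accuracy_score(test_y, predict)) < 1e-12
```

With their default arguments, precision_score and recall_score treat label 1 as the positive class, so the script as written assumes a binary 0/1 label column.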

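In the SVMCV run above, svm_cross_validation prints every parameter of the best estimator and then trains a new SVC with the chosen C and gamma. A shorter variant is sketched below: GridSearchCV refits the best configuration on the full training data by default (refit=True), so best_estimator_ can be used directly and best_params_ holds only the searched values. The synthetic data and the smaller grid here are assumptions for illustration, and cv=3 is set explicitly because recent scikit-learn releases default to 5 folds rather than the 3 shown in the log.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Assumed synthetic data standing in for the article's training set
x, y = make_classification(n_samples=120, n_features=4, n_informative=3,
                           n_redundant=0, random_state=0)

# A reduced grid to keep the sketch fast; the article searches a wider range of C
param_grid = {'C': [0.1, 1, 10], 'gamma': [0.001, 0.01]}
grid_search = GridSearchCV(SVC(kernel='rbf'), param_grid, cv=3)
grid_search.fit(x, y)

# Only the searched parameters, instead of the full get_params() dump
print(grid_search.best_params_)

# Already refit on all of x, y with the best parameters (refit=True by default)
model = grid_search.best_estimator_
print(model.predict(x[:5]))
```
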
I hope this article is helpful for your Python programming.
