A Simple Example of Applying Classification Algorithms with Python's sklearn Library


Posted in Python on July 09, 2018

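This article gives a simple example of applying the classification algorithms in Python's sklearn library. It is shared here for reference; the details are as follows:

scikit-learn ships with Anaconda; it can also be installed from the source package available on the official site. The code below wraps the following machine-learning algorithms, so you only need to modify the data-loading function to test them all in one run: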

# coding=utf-8
'''
Created on June 4, 2016
@author: bryan
'''
import time
from sklearn import metrics
import pickle
import pandas as pd
# Multinomial Naive Bayes Classifier
def naive_bayes_classifier(train_x, train_y):
  from sklearn.naive_bayes import MultinomialNB
  model = MultinomialNB(alpha=0.01)
  model.fit(train_x, train_y)
  return model
# KNN Classifier
def knn_classifier(train_x, train_y):
  from sklearn.neighbors import KNeighborsClassifier
  model = KNeighborsClassifier()
  model.fit(train_x, train_y)
  return model
# Logistic Regression Classifier
def logistic_regression_classifier(train_x, train_y):
  from sklearn.linear_model import LogisticRegression
  model = LogisticRegression(penalty='l2')
  model.fit(train_x, train_y)
  return model
# Random Forest Classifier
def random_forest_classifier(train_x, train_y):
  from sklearn.ensemble import RandomForestClassifier
  model = RandomForestClassifier(n_estimators=8)
  model.fit(train_x, train_y)
  return model
# Decision Tree Classifier
def decision_tree_classifier(train_x, train_y):
  from sklearn import tree
  model = tree.DecisionTreeClassifier()
  model.fit(train_x, train_y)
  return model
# GBDT(Gradient Boosting Decision Tree) Classifier
def gradient_boosting_classifier(train_x, train_y):
  from sklearn.ensemble import GradientBoostingClassifier
  model = GradientBoostingClassifier(n_estimators=200)
  model.fit(train_x, train_y)
  return model
# SVM Classifier
def svm_classifier(train_x, train_y):
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  model.fit(train_x, train_y)
  return model
# SVM Classifier using cross validation
def svm_cross_validation(train_x, train_y):
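  # sklearn.grid_search was removed in scikit-learn 0.20; GridSearchCV now lives in model_selection
  from sklearn.model_selection import GridSearchCV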
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1)
  grid_search.fit(train_x, train_y)
  best_parameters = grid_search.best_estimator_.get_params()
  for para, val in list(best_parameters.items()):
    print(para, val)
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True)
  model.fit(train_x, train_y)
  return model
def read_data(data_file):
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
if __name__ == '__main__':
  data_file = "H:\\Research\\data\\trainCG.csv"
  thresh = 0.5
  model_save_file = None
  model_save = {}
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT']
  classifiers = {'NB':naive_bayes_classifier,
         'KNN':knn_classifier,
          'LR':logistic_regression_classifier,
          'RF':random_forest_classifier,
          'DT':decision_tree_classifier,
         'SVM':svm_classifier,
        'SVMCV':svm_cross_validation,
         'GBDT':gradient_boosting_classifier
  }
  print('reading training and testing data...')
  train_x, train_y, test_x, test_y = read_data(data_file)
  for classifier in test_classifiers:
    print('******************* %s ********************' % classifier)
    start_time = time.time()
    model = classifiers[classifier](train_x, train_y)
    print('training took %fs!' % (time.time() - start_time))
    predict = model.predict(test_x)
    if model_save_file is not None:
      model_save[classifier] = model
    precision = metrics.precision_score(test_y, predict)
    recall = metrics.recall_score(test_y, predict)
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
    accuracy = metrics.accuracy_score(test_y, predict)
    print('accuracy: %.2f%%' % (100 * accuracy))
  if model_save_file is not None:
    # Use a context manager so the file is closed after the models are written
    with open(model_save_file, 'wb') as f:
      pickle.dump(model_save, f)
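
Since only the data-loading function has to change, a stand-in for read_data that needs no CSV file might look like the sketch below. The function name read_synthetic_data and its parameters are hypothetical and not part of the original script. It assumes a generated binary dataset from sklearn's make_classification is acceptable for a smoke test, and it shifts the features to be non-negative because MultinomialNB rejects negative values.

```python
import pandas as pd
from sklearn.datasets import make_classification


def read_synthetic_data(n_samples=500, n_features=10, train_fraction=0.9, seed=0):
    """Hypothetical replacement for read_data(); returns the same four values."""
    # make_classification shuffles its samples by default, so a sequential
    # split still gives a mix of both classes in each part.
    x, y = make_classification(n_samples=n_samples, n_features=n_features,
                               random_state=seed)
    # MultinomialNB raises an error on negative inputs, so shift each column to start at 0.
    x = x - x.min(axis=0)
    data = pd.DataFrame(x, columns=['f%d' % i for i in range(n_features)])
    data['label'] = y
    # Same 90/10 sequential split as read_data() in the script above.
    split = int(len(data) * train_fraction)
    train, test = data[:split], data[split:]
    return (train.drop('label', axis=1), train['label'],
            test.drop('label', axis=1), test['label'])


train_x, train_y, test_x, test_y = read_synthetic_data()
print(train_x.shape, test_x.shape)
```

Swapping the call to read_data(data_file) in the main block for read_synthetic_data() should let the loop over classifiers run unchanged, although the scores will of course differ from those reported for the author's data.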

The test results are as follows:

reading training and testing data...
******************* NB ********************
training took 0.004986s!
precision: 78.08%, recall: 71.25%
accuracy: 74.17%
******************* KNN ********************
training took 0.017545s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
******************* LR ********************
training took 0.061161s!
precision: 89.16%, recall: 92.50%
accuracy: 90.07%
******************* RF ********************
training took 0.040111s!
precision: 96.39%, recall: 100.00%
accuracy: 98.01%
******************* DT ********************
training took 0.004513s!
precision: 96.20%, recall: 95.00%
accuracy: 95.36%
******************* SVM ********************
training took 0.242145s!
precision: 97.53%, recall: 98.75%
accuracy: 98.01%
******************* SVMCV ********************
Fitting 3 folds for each of 14 candidates, totalling 42 fits
[Parallel(n_jobs=1)]: Done  42 out of  42 | elapsed:    6.8s finished
probability True
verbose False
coef0 0.0
degree 3
tol 0.001
shrinking True
cache_size 200
gamma 0.001
max_iter -1
C 1000
decision_function_shape None
random_state None
class_weight None
kernel rbf
training took 7.434668s!
precision: 98.75%, recall: 98.75%
accuracy: 98.68%
******************* GBDT ********************
training took 0.521916s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
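
To make the precision, recall, and accuracy figures above easier to interpret, here is a minimal sketch that computes the three metrics by hand from a confusion matrix and compares them with the sklearn functions the script calls. The labels are hypothetical, chosen only so the arithmetic is easy to check; they are not taken from the author's data.

```python
from sklearn import metrics

# Hypothetical labels: 4 positives and 6 negatives, with one miss and one false alarm.
true_y = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
pred_y = [1, 1, 1, 0, 1, 0, 0, 0, 0, 0]

# For binary labels, confusion_matrix(...).ravel() yields tn, fp, fn, tp in that order.
tn, fp, fn, tp = metrics.confusion_matrix(true_y, pred_y).ravel()

precision = tp / (tp + fp)                  # share of predicted positives that are correct
recall = tp / (tp + fn)                     # share of actual positives that were found
accuracy = (tp + tn) / (tp + tn + fp + fn)  # share of all predictions that are correct

print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
print('accuracy: %.2f%%' % (100 * accuracy))

# The hand-computed values agree with the library functions used in the script.
assert abs(precision - metrics.precision_score(true_y, pred_y)) < 1e-12
assert abs(recall - metrics.recall_score(true_y, pred_y)) < 1e-12
assert abs(accuracy - metrics.accuracy_score(true_y, pred_y)) < 1e-12
```

Note that precision_score and recall_score default to treating label 1 as the positive class in a binary problem, which is why the script above can call them without extra arguments when the label column holds 0 and 1.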

I hope this article is helpful for your Python programming.
