A Simple Example of Classification Algorithms in Python with scikit-learn


Posted in Python on July 09, 2018

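This article presents a simple example of applying classification algorithms in Python with the scikit-learn (sklearn) library, shared here for your reference. The details are as follows: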

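scikit-learn is already included in Anaconda; it can also be installed from the source package available on the official site. The code in this article wraps the following machine learning algorithms: multinomial Naive Bayes, KNN, logistic regression, random forest, decision tree, GBDT, SVM, and SVM tuned by grid-search cross-validation. By changing only the data-loading function (`read_data`), you can test all of them in one run.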
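For example, if you do not have a CSV file at hand, a replacement for `read_data` could load one of scikit-learn's bundled datasets and return the same four objects in the same order. The sketch below is an assumption, not part of the original script: the function name `read_builtin_data` is hypothetical, and it swaps the script's sequential 90/10 split for a shuffled, stratified 90/10 split via `train_test_split`. It uses `load_breast_cancer`, a binary task (labels 0 and 1) whose features are all non-negative, so `MultinomialNB` accepts it as well.

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split


def read_builtin_data(test_size=0.1, random_state=0):
    """Hypothetical drop-in replacement for read_data().

    Returns (train_x, train_y, test_x, test_y), the same order as read_data().
    """
    data = load_breast_cancer()
    x = pd.DataFrame(data.data, columns=data.feature_names)
    y = pd.Series(data.target, name='label')
    # Shuffle and stratify so both splits keep roughly the same class ratio
    train_x, test_x, train_y, test_y = train_test_split(
        x, y, test_size=test_size, random_state=random_state, stratify=y)
    return train_x, train_y, test_x, test_y
```

To use it, replace `read_data(data_file)` in the `__main__` block with `read_builtin_data()`; the rest of the script stays unchanged.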

# coding=gbk
'''
Created on June 4, 2016
@author: bryan
'''
import time
from sklearn import metrics
import pickle
import pandas as pd
# Multinomial Naive Bayes Classifier (requires non-negative features, e.g. counts)
def naive_bayes_classifier(train_x, train_y):
  from sklearn.naive_bayes import MultinomialNB
  model = MultinomialNB(alpha=0.01)
  model.fit(train_x, train_y)
  return model
# KNN Classifier
def knn_classifier(train_x, train_y):
  from sklearn.neighbors import KNeighborsClassifier
  model = KNeighborsClassifier()
  model.fit(train_x, train_y)
  return model
# Logistic Regression Classifier
def logistic_regression_classifier(train_x, train_y):
  from sklearn.linear_model import LogisticRegression
  model = LogisticRegression(penalty='l2')
  model.fit(train_x, train_y)
  return model
# Random Forest Classifier
def random_forest_classifier(train_x, train_y):
  from sklearn.ensemble import RandomForestClassifier
  model = RandomForestClassifier(n_estimators=8)
  model.fit(train_x, train_y)
  return model
# Decision Tree Classifier
def decision_tree_classifier(train_x, train_y):
  from sklearn import tree
  model = tree.DecisionTreeClassifier()
  model.fit(train_x, train_y)
  return model
# GBDT(Gradient Boosting Decision Tree) Classifier
def gradient_boosting_classifier(train_x, train_y):
  from sklearn.ensemble import GradientBoostingClassifier
  model = GradientBoostingClassifier(n_estimators=200)
  model.fit(train_x, train_y)
  return model
# SVM Classifier
def svm_classifier(train_x, train_y):
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  model.fit(train_x, train_y)
  return model
# SVM Classifier using cross validation (grid search over C and gamma)
def svm_cross_validation(train_x, train_y):
  # sklearn.grid_search was removed in scikit-learn 0.20; GridSearchCV lives in sklearn.model_selection
  from sklearn.model_selection import GridSearchCV
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}
  grid_search = GridSearchCV(model, param_grid, n_jobs=1, verbose=1)
  grid_search.fit(train_x, train_y)
  best_parameters = grid_search.best_estimator_.get_params()
  for para, val in best_parameters.items():
    print(para, val)
  # Retrain on the full training set with the best C and gamma found by the search
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True)
  model.fit(train_x, train_y)
  return model
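def read_data(data_file):
  # Expects a CSV with a 'label' column. The first 90% of rows (file order, no shuffling)
  # become the training set and the last 10% the test set.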
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
if __name__ == '__main__':
  data_file = "H:\\Research\\data\\trainCG.csv"
  thresh = 0.5
  model_save_file = None  # set to a file path to pickle all trained models
  model_save = {}
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM', 'SVMCV', 'GBDT']
  classifiers = {'NB': naive_bayes_classifier,
                 'KNN': knn_classifier,
                 'LR': logistic_regression_classifier,
                 'RF': random_forest_classifier,
                 'DT': decision_tree_classifier,
                 'SVM': svm_classifier,
                 'SVMCV': svm_cross_validation,
                 'GBDT': gradient_boosting_classifier
  }
  print('reading training and testing data...')
  train_x, train_y, test_x, test_y = read_data(data_file)
  for classifier in test_classifiers:
    print('******************* %s ********************' % classifier)
    start_time = time.time()
    model = classifiers[classifier](train_x, train_y)
    print('training took %fs!' % (time.time() - start_time))
    predict = model.predict(test_x)
    if model_save_file is not None:
      model_save[classifier] = model
    # precision_score / recall_score default to binary labels with 1 as the positive class
    precision = metrics.precision_score(test_y, predict)
    recall = metrics.recall_score(test_y, predict)
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
    accuracy = metrics.accuracy_score(test_y, predict)
    print('accuracy: %.2f%%' % (100 * accuracy))
  if model_save_file is not None:
    # Use a context manager so the file is closed after writing
    with open(model_save_file, 'wb') as f:
      pickle.dump(model_save, f)
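The `SVMCV` entry in the results ("Fitting 3 folds for each of 14 candidates, totalling 42 fits") follows directly from `param_grid` in `svm_cross_validation`: 7 values of `C` times 2 values of `gamma` give 14 candidate settings, and each is fitted once per fold. The stdlib-only sketch below reproduces that count; it mirrors what `GridSearchCV` enumerates but is not scikit-learn's own implementation. The fold count of 3 is taken from the log, which reflects the default of older scikit-learn releases (newer releases default to 5 folds).

```python
from itertools import product

# Same grid as svm_cross_validation() in the script
param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}


def enumerate_grid(grid):
    """Return every parameter combination as a dict (one dict per candidate)."""
    keys = sorted(grid)
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


candidates = enumerate_grid(param_grid)
n_folds = 3  # matches the "3 folds" in the SVMCV log
print(len(candidates), len(candidates) * n_folds)  # 14 42
```

Each extra value added to `C` or `gamma` multiplies the number of fits, which is why `SVMCV` takes far longer to train than the plain `SVM` entry.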

The test results from the author's run are as follows:

reading training and testing data...
******************* NB ********************
training took 0.004986s!
precision: 78.08%, recall: 71.25%
accuracy: 74.17%
******************* KNN ********************
training took 0.017545s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
******************* LR ********************
training took 0.061161s!
precision: 89.16%, recall: 92.50%
accuracy: 90.07%
******************* RF ********************
training took 0.040111s!
precision: 96.39%, recall: 100.00%
accuracy: 98.01%
******************* DT ********************
training took 0.004513s!
precision: 96.20%, recall: 95.00%
accuracy: 95.36%
******************* SVM ********************
training took 0.242145s!
precision: 97.53%, recall: 98.75%
accuracy: 98.01%
******************* SVMCV ********************
Fitting 3 folds for each of 14 candidates, totalling 42 fits
[Parallel(n_jobs=1)]: Done  42 out of  42 | elapsed:    6.8s finished
probability True
verbose False
coef0 0.0
degree 3
tol 0.001
shrinking True
cache_size 200
gamma 0.001
max_iter -1
C 1000
decision_function_shape None
random_state None
class_weight None
kernel rbf
training took 7.434668s!
precision: 98.75%, recall: 98.75%
accuracy: 98.68%
******************* GBDT ********************
training took 0.521916s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
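The three scores printed for each classifier follow the standard binary definitions, with label 1 as the positive class (the default of `metrics.precision_score` and `metrics.recall_score`): precision = TP / (TP + FP), recall = TP / (TP + FN), and accuracy = (TP + TN) / total. The stdlib-only sketch below recomputes them by hand; the function name `binary_scores` is hypothetical and the labels are made up for illustration, not taken from the author's data.

```python
def binary_scores(y_true, y_pred, positive=1):
    """Return (precision, recall, accuracy), treating `positive` as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    correct = sum(1 for t, p in pairs if t == p)
    # Guard against division by zero when nothing is predicted/labelled positive
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = correct / len(pairs)
    return precision, recall, accuracy


# Made-up labels for illustration only
y_true = [1, 1, 1, 0, 0, 1, 0, 1]
y_pred = [1, 0, 1, 0, 1, 1, 0, 1]
print(binary_scores(y_true, y_pred))  # (0.8, 0.8, 0.75)
```

A recall of 100% (as for KNN, RF, and GBDT above) therefore means FN = 0 on the test split: every positive test sample was predicted as positive.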

I hope this article is helpful for your Python programming.
