Python基于sklearn库的分类算法简单应用示例


Posted in Python onJuly 09, 2018

本文实例讲述了Python基于sklearn库的分类算法简单应用。分享给大家供大家参考,具体如下:

scikit-learn已经包含在Anaconda中。也可以在官方下载源码包进行安装。本文代码里封装了如下机器学习算法,我们修改数据加载函数,即可一键测试:

# coding=gbk
'''
Created on 2016年6月4日
@author: bryan
'''
import time
from sklearn import metrics
import pickle as pickle
import pandas as pd
# Multinomial Naive Bayes Classifier
def naive_bayes_classifier(train_x, train_y):
  from sklearn.naive_bayes import MultinomialNB
  model = MultinomialNB(alpha=0.01)
  model.fit(train_x, train_y)
  return model
# KNN Classifier
def knn_classifier(train_x, train_y):
  from sklearn.neighbors import KNeighborsClassifier
  model = KNeighborsClassifier()
  model.fit(train_x, train_y)
  return model
# Logistic Regression Classifier
def logistic_regression_classifier(train_x, train_y):
  from sklearn.linear_model import LogisticRegression
  model = LogisticRegression(penalty='l2')
  model.fit(train_x, train_y)
  return model
# Random Forest Classifier
def random_forest_classifier(train_x, train_y):
  from sklearn.ensemble import RandomForestClassifier
  model = RandomForestClassifier(n_estimators=8)
  model.fit(train_x, train_y)
  return model
# Decision Tree Classifier
def decision_tree_classifier(train_x, train_y):
  from sklearn import tree
  model = tree.DecisionTreeClassifier()
  model.fit(train_x, train_y)
  return model
# GBDT(Gradient Boosting Decision Tree) Classifier
def gradient_boosting_classifier(train_x, train_y):
  from sklearn.ensemble import GradientBoostingClassifier
  model = GradientBoostingClassifier(n_estimators=200)
  model.fit(train_x, train_y)
  return model
# SVM Classifier
def svm_classifier(train_x, train_y):
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  model.fit(train_x, train_y)
  return model
# SVM Classifier using cross validation
def svm_cross_validation(train_x, train_y):
  from sklearn.grid_search import GridSearchCV
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1)
  grid_search.fit(train_x, train_y)
  best_parameters = grid_search.best_estimator_.get_params()
  for para, val in list(best_parameters.items()):
    print(para, val)
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True)
  model.fit(train_x, train_y)
  return model
def read_data(data_file):
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
if __name__ == '__main__':
  data_file = "H:\\Research\\data\\trainCG.csv"
  thresh = 0.5
  model_save_file = None
  model_save = {}
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT']
  classifiers = {'NB':naive_bayes_classifier,
         'KNN':knn_classifier,
          'LR':logistic_regression_classifier,
          'RF':random_forest_classifier,
          'DT':decision_tree_classifier,
         'SVM':svm_classifier,
        'SVMCV':svm_cross_validation,
         'GBDT':gradient_boosting_classifier
  }
  print('reading training and testing data...')
  train_x, train_y, test_x, test_y = read_data(data_file)
  for classifier in test_classifiers:
    print('******************* %s ********************' % classifier)
    start_time = time.time()
    model = classifiers[classifier](train_x, train_y)
    print('training took %fs!' % (time.time() - start_time))
    predict = model.predict(test_x)
    if model_save_file != None:
      model_save[classifier] = model
    precision = metrics.precision_score(test_y, predict)
    recall = metrics.recall_score(test_y, predict)
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
    accuracy = metrics.accuracy_score(test_y, predict)
    print('accuracy: %.2f%%' % (100 * accuracy))
  if model_save_file != None:
    pickle.dump(model_save, open(model_save_file, 'wb'))

测试结果如下:

reading training and testing data...
******************* NB ********************
training took 0.004986s!
precision: 78.08%, recall: 71.25%
accuracy: 74.17%
******************* KNN ********************
training took 0.017545s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
******************* LR ********************
training took 0.061161s!
precision: 89.16%, recall: 92.50%
accuracy: 90.07%
******************* RF ********************
training took 0.040111s!
precision: 96.39%, recall: 100.00%
accuracy: 98.01%
******************* DT ********************
training took 0.004513s!
precision: 96.20%, recall: 95.00%
accuracy: 95.36%
******************* SVM ********************
training took 0.242145s!
precision: 97.53%, recall: 98.75%
accuracy: 98.01%
******************* SVMCV ********************
Fitting 3 folds for each of 14 candidates, totalling 42 fits
[Parallel(n_jobs=1)]: Done  42 out of  42 | elapsed:    6.8s finished
probability True
verbose False
coef0 0.0
degree 3
tol 0.001
shrinking True
cache_size 200
gamma 0.001
max_iter -1
C 1000
decision_function_shape None
random_state None
class_weight None
kernel rbf
training took 7.434668s!
precision: 98.75%, recall: 98.75%
accuracy: 98.68%
******************* GBDT ********************
training took 0.521916s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python django 访问静态文件出现404或500错误
Jan 20 Python
mac安装scrapy并创建项目的实例讲解
Jun 13 Python
Python中return self的用法详解
Jul 27 Python
解决Shell执行python文件,传参空格引起的问题
Oct 30 Python
Python制作exe文件简单流程
Jan 24 Python
Django框架会话技术实例分析【Cookie与Session】
May 24 Python
python多线程案例之多任务copy文件完整实例
Oct 29 Python
有关Tensorflow梯度下降常用的优化方法分享
Feb 04 Python
Python如何使用bokeh包和geojson数据绘制地图
Mar 21 Python
jupyter notebook运行命令显示[*](解决办法)
May 18 Python
一个入门级python爬虫教程详解
Jan 27 Python
python数据库批量插入数据的实现(executemany的使用)
Apr 30 Python
Python不使用int()函数把字符串转换为数字的方法
Jul 09 #Python
python中ASCII码和字符的转换方法
Jul 09 #Python
python中ASCII码字符与int之间的转换方法
Jul 09 #Python
Python 十六进制整数与ASCii编码字符串相互转换方法
Jul 09 #Python
python 以16进制打印输出的方法
Jul 09 #Python
python爬虫之urllib3的使用示例
Jul 09 #Python
机器学习之KNN算法原理及Python实现方法详解
Jul 09 #Python
You might like
php多维数组去掉重复值示例分享
2014/03/02 PHP
php多重接口的实现方法
2015/06/20 PHP
又十个超级有用的PHP代码片段
2015/09/24 PHP
php源码之将图片转化为data/base64数据流实例详解
2016/11/27 PHP
Yii框架实现记录日志到自定义文件的方法
2017/05/23 PHP
JavaScript Konami Code 实现代码
2009/07/29 Javascript
javascript真的不难-回顾一下基础知识
2013/01/15 Javascript
Javascript 正则表达式实现为数字添加千位分隔符
2015/03/10 Javascript
yui3的AOP(面向切面编程)和OOP(面向对象编程)
2015/05/01 Javascript
js实现YouKu的漂亮搜索框效果
2015/08/19 Javascript
AngularJS教程之环境设置
2016/08/16 Javascript
JS焦点图,JS 多个页面放多个焦点图的实例
2016/12/08 Javascript
理解nodejs的stream和pipe机制的原理和实现
2017/08/12 NodeJs
iview table高度动态设置方法
2018/03/14 Javascript
JavaScript中工厂函数与构造函数示例详解
2019/05/06 Javascript
Openlayers实现图形绘制
2020/09/28 Javascript
基于ajax实现上传图片代码示例解析
2020/12/03 Javascript
python list中append()与extend()用法分享
2013/03/24 Python
详解python之简单主机批量管理工具
2017/01/27 Python
Python Pillow Image Invert
2019/01/22 Python
python opencv判断图像是否为空的实例
2019/01/26 Python
详解python配置虚拟环境
2019/04/08 Python
Python爬虫实现的根据分类爬取豆瓣电影信息功能示例
2019/09/15 Python
Python3直接爬取图片URL并保存示例
2019/12/18 Python
python 爬取马蜂窝景点翻页文字评论的实现
2020/01/20 Python
python 已知一个字符,在一个list中找出近似值或相似值实现模糊匹配
2020/02/29 Python
Python PyQt5模块实现窗口GUI界面代码实例
2020/05/12 Python
django前端页面下拉选择框默认值设置方式
2020/08/09 Python
Python Map 函数的使用
2020/08/28 Python
python 递归相关知识总结
2021/03/03 Python
英国领先的汽车轮胎和快速健康中心:Kwik Fit
2017/10/29 全球购物
生产主管岗位职责
2013/11/10 职场文书
反四风个人对照检查材料
2014/09/26 职场文书
七年级地理教学计划
2015/01/22 职场文书
手术室护士个人总结
2015/02/13 职场文书
年度考核个人总结
2015/03/06 职场文书