A Simple Example of Classification Algorithms in Python with the sklearn Library


Posted in Python on July 09, 2018

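This article walks through a simple application of classification algorithms in Python using the sklearn library. It is shared here for reference; the details are as follows:

scikit-learn ships with Anaconda. It can also be installed from the source package on the official website. The code below wraps the following classifiers: multinomial naive Bayes, KNN, logistic regression, random forest, decision tree, GBDT, SVM, and SVM with a cross-validated grid search. To test them all in one run on your own data, only the data-loading function (read_data) needs to change: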

'''
Created on June 4, 2016
@author: bryan
'''
import time
from sklearn import metrics
import pickle
import pandas as pd
# Multinomial Naive Bayes Classifier
def naive_bayes_classifier(train_x, train_y):
  from sklearn.naive_bayes import MultinomialNB
  model = MultinomialNB(alpha=0.01)
  model.fit(train_x, train_y)
  return model
# KNN Classifier
def knn_classifier(train_x, train_y):
  from sklearn.neighbors import KNeighborsClassifier
  model = KNeighborsClassifier()
  model.fit(train_x, train_y)
  return model
# Logistic Regression Classifier
def logistic_regression_classifier(train_x, train_y):
  from sklearn.linear_model import LogisticRegression
  model = LogisticRegression(penalty='l2')
  model.fit(train_x, train_y)
  return model
# Random Forest Classifier
def random_forest_classifier(train_x, train_y):
  from sklearn.ensemble import RandomForestClassifier
  model = RandomForestClassifier(n_estimators=8)
  model.fit(train_x, train_y)
  return model
# Decision Tree Classifier
def decision_tree_classifier(train_x, train_y):
  from sklearn import tree
  model = tree.DecisionTreeClassifier()
  model.fit(train_x, train_y)
  return model
# GBDT(Gradient Boosting Decision Tree) Classifier
def gradient_boosting_classifier(train_x, train_y):
  from sklearn.ensemble import GradientBoostingClassifier
  model = GradientBoostingClassifier(n_estimators=200)
  model.fit(train_x, train_y)
  return model
# SVM Classifier
def svm_classifier(train_x, train_y):
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  model.fit(train_x, train_y)
  return model
# SVM Classifier using cross validation
def svm_cross_validation(train_x, train_y):
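  # sklearn.grid_search was removed in scikit-learn 0.20; GridSearchCV now lives in model_selection.
  from sklearn.model_selection import GridSearchCV
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}
  # cv=3 keeps the 3-fold default of the older GridSearchCV (newer releases default to 5 folds).
  grid_search = GridSearchCV(model, param_grid, n_jobs=1, verbose=1, cv=3)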
  grid_search.fit(train_x, train_y)
  best_parameters = grid_search.best_estimator_.get_params()
  for para, val in list(best_parameters.items()):
    print(para, val)
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True)
  model.fit(train_x, train_y)
  return model
def read_data(data_file):
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
if __name__ == '__main__':
  data_file = "H:\\Research\\data\\trainCG.csv"
  thresh = 0.5
  model_save_file = None
  model_save = {}
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT']
  classifiers = {'NB': naive_bayes_classifier,
                 'KNN': knn_classifier,
                 'LR': logistic_regression_classifier,
                 'RF': random_forest_classifier,
                 'DT': decision_tree_classifier,
                 'SVM': svm_classifier,
                 'SVMCV': svm_cross_validation,
                 'GBDT': gradient_boosting_classifier}
  print('reading training and testing data...')
  train_x, train_y, test_x, test_y = read_data(data_file)
  for classifier in test_classifiers:
    print('******************* %s ********************' % classifier)
    start_time = time.time()
    model = classifiers[classifier](train_x, train_y)
    print('training took %fs!' % (time.time() - start_time))
    predict = model.predict(test_x)
    if model_save_file is not None:
      model_save[classifier] = model
    precision = metrics.precision_score(test_y, predict)
    recall = metrics.recall_score(test_y, predict)
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
    accuracy = metrics.accuracy_score(test_y, predict)
    print('accuracy: %.2f%%' % (100 * accuracy))
  if model_save_file is not None:
    # Use a context manager so the file is closed after writing.
    with open(model_save_file, 'wb') as f:
      pickle.dump(model_save, f)
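
The script expects a CSV file with a `label` column plus numeric feature columns, and `data_file` points to a path on the author's machine. To try the script without that file, one option is to generate a stand-in dataset. The sketch below is an assumption and not part of the original article: the helper name `make_demo_csv`, the file name `demo_train.csv`, and the use of `make_classification` are purely illustrative. Features are shifted to be non-negative because `MultinomialNB` rejects negative values.

```python
import pandas as pd
from sklearn.datasets import make_classification


def make_demo_csv(path, n_samples=1000, n_features=10, seed=0):
  """Write a synthetic binary-classification CSV with a 'label' column.

  Hypothetical helper: the column layout mirrors what read_data() expects,
  but the data is unrelated to the author's trainCG.csv.
  """
  x, y = make_classification(n_samples=n_samples, n_features=n_features,
                             n_informative=5, random_state=seed)
  # MultinomialNB requires non-negative features, so shift each column.
  x = x - x.min(axis=0)
  columns = ['f%d' % i for i in range(n_features)]
  data = pd.DataFrame(x, columns=columns)
  data['label'] = y
  # make_classification shuffles rows by default, so the positional 90/10
  # split in read_data() should see both classes in each part.
  data.to_csv(path, index=False)
  return data
```

After calling `make_demo_csv('demo_train.csv')`, setting `data_file = 'demo_train.csv'` should let the main block run end to end. The scores will of course differ from the results reported in this article.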

The test results are as follows:

reading training and testing data...
******************* NB ********************
training took 0.004986s!
precision: 78.08%, recall: 71.25%
accuracy: 74.17%
******************* KNN ********************
training took 0.017545s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
******************* LR ********************
training took 0.061161s!
precision: 89.16%, recall: 92.50%
accuracy: 90.07%
******************* RF ********************
training took 0.040111s!
precision: 96.39%, recall: 100.00%
accuracy: 98.01%
******************* DT ********************
training took 0.004513s!
precision: 96.20%, recall: 95.00%
accuracy: 95.36%
******************* SVM ********************
training took 0.242145s!
precision: 97.53%, recall: 98.75%
accuracy: 98.01%
******************* SVMCV ********************
Fitting 3 folds for each of 14 candidates, totalling 42 fits
[Parallel(n_jobs=1)]: Done  42 out of  42 | elapsed:    6.8s finished
probability True
verbose False
coef0 0.0
degree 3
tol 0.001
shrinking True
cache_size 200
gamma 0.001
max_iter -1
C 1000
decision_function_shape None
random_state None
class_weight None
kernel rbf
training took 7.434668s!
precision: 98.75%, recall: 98.75%
accuracy: 98.68%
******************* GBDT ********************
training took 0.521916s!
precision: 97.56%, recall: 100.00%
accuracy: 98.68%
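
The precision, recall, and accuracy figures above come from `sklearn.metrics`. As a reminder of what they measure, the sketch below recomputes the three scores by hand from a confusion matrix. The `y_true` and `y_pred` vectors are made up for illustration and are not the article's data.

```python
from sklearn import metrics

# Made-up labels for illustration only (not the article's data).
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 0, 0, 0, 0, 0]

# For binary labels [0, 1], ravel() yields tn, fp, fn, tp in that order.
tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()

precision = tp / (tp + fp)  # share of predicted positives that are correct
recall = tp / (tp + fn)     # share of actual positives that are found
accuracy = (tp + tn) / (tp + tn + fp + fn)

print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
print('accuracy: %.2f%%' % (100 * accuracy))
# prints:
# precision: 75.00%, recall: 75.00%
# accuracy: 80.00%
```

The hand-computed values agree with `metrics.precision_score`, `metrics.recall_score`, and `metrics.accuracy_score` on the same inputs. Note that the script calls these with their default binary averaging, so it assumes the `label` column holds two classes.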

I hope this article helps you with your Python programming.
