python sklearn常用分类算法模型的调用


Posted in Python onOctober 16, 2019

本文实例为大家分享了python sklearn分类算法模型调用的具体代码,供大家参考,具体内容如下

实现对'NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT'模型的简单调用。

# coding=gbk
 
import time 
from sklearn import metrics 
import pickle as pickle 
import pandas as pd
 
 
# Multinomial Naive Bayes Classifier 
def naive_bayes_classifier(train_x, train_y): 
  from sklearn.naive_bayes import MultinomialNB 
  model = MultinomialNB(alpha=0.01) 
  model.fit(train_x, train_y) 
  return model 
 
 
# KNN Classifier 
def knn_classifier(train_x, train_y): 
  from sklearn.neighbors import KNeighborsClassifier 
  model = KNeighborsClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# Logistic Regression Classifier 
def logistic_regression_classifier(train_x, train_y): 
  from sklearn.linear_model import LogisticRegression 
  model = LogisticRegression(penalty='l2') 
  model.fit(train_x, train_y) 
  return model 
 
 
# Random Forest Classifier 
def random_forest_classifier(train_x, train_y): 
  from sklearn.ensemble import RandomForestClassifier 
  model = RandomForestClassifier(n_estimators=8) 
  model.fit(train_x, train_y) 
  return model 
 
 
# Decision Tree Classifier 
def decision_tree_classifier(train_x, train_y): 
  from sklearn import tree 
  model = tree.DecisionTreeClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# GBDT(Gradient Boosting Decision Tree) Classifier 
def gradient_boosting_classifier(train_x, train_y): 
  from sklearn.ensemble import GradientBoostingClassifier 
  model = GradientBoostingClassifier(n_estimators=200) 
  model.fit(train_x, train_y) 
  return model 
 
 
# SVM Classifier 
def svm_classifier(train_x, train_y): 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
# SVM Classifier using cross validation 
def svm_cross_validation(train_x, train_y): 
  from sklearn.grid_search import GridSearchCV 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]} 
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1) 
  grid_search.fit(train_x, train_y) 
  best_parameters = grid_search.best_estimator_.get_params() 
  for para, val in list(best_parameters.items()): 
    print(para, val) 
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
def read_data(data_file): 
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
   
if __name__ == '__main__': 
  data_file = "H:\\Research\\data\\trainCG.csv" 
  thresh = 0.5 
  model_save_file = None 
  model_save = {} 
  
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT'] 
  classifiers = {'NB':naive_bayes_classifier,  
         'KNN':knn_classifier, 
          'LR':logistic_regression_classifier, 
          'RF':random_forest_classifier, 
          'DT':decision_tree_classifier, 
         'SVM':svm_classifier, 
        'SVMCV':svm_cross_validation, 
         'GBDT':gradient_boosting_classifier 
  } 
   
  print('reading training and testing data...') 
  train_x, train_y, test_x, test_y = read_data(data_file) 
   
  for classifier in test_classifiers: 
    print('******************* %s ********************' % classifier) 
    start_time = time.time() 
    model = classifiers[classifier](train_x, train_y) 
    print('training took %fs!' % (time.time() - start_time)) 
    predict = model.predict(test_x) 
    if model_save_file != None: 
      model_save[classifier] = model 
    precision = metrics.precision_score(test_y, predict) 
    recall = metrics.recall_score(test_y, predict) 
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall)) 
    accuracy = metrics.accuracy_score(test_y, predict) 
    print('accuracy: %.2f%%' % (100 * accuracy))  
 
  if model_save_file != None: 
    pickle.dump(model_save, open(model_save_file, 'wb'))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
400多行Python代码实现了一个FTP服务器
May 10 Python
举例介绍Python中的25个隐藏特性
Mar 30 Python
Python3实现简单可学习的手写体识别(实例讲解)
Oct 21 Python
python破解zip加密文件的方法
May 31 Python
Python实现将HTML转成PDF的方法分析
May 04 Python
Python实现的统计文章单词次数功能示例
Jul 08 Python
wxPython之wx.DC绘制形状
Nov 19 Python
python调用接口的4种方式代码实例
Nov 19 Python
Jupyter Notebook的连接密码 token查询方式
Apr 21 Python
Django表单提交后实现获取相同name的不同value值
May 14 Python
tensorflow中的梯度求解及梯度裁剪操作
May 26 Python
Python 中的单分派泛函数你真的了解吗
Jun 22 Python
Python使用selenium + headless chrome获取网页内容的方法示例
Oct 16 #Python
使用python实现kNN分类算法
Oct 16 #Python
Python生成验证码、计算具体日期是一年中的第几天实例代码详解
Oct 16 #Python
python可视化实现KNN算法
Oct 16 #Python
python实现KNN分类算法
Oct 16 #Python
python子线程退出及线程退出控制的代码
Oct 16 #Python
python Pillow图像处理方法汇总
Oct 16 #Python
You might like
YB217、YB235、YB400浅听
2021/03/02 无线电
在任意字符集下正常显示网页的方法一
2007/04/01 PHP
PHP伪造referer实例代码
2008/09/20 PHP
PHP程序员常见的40个陋习,你中了几个?
2014/11/20 PHP
php查看当前Session的ID实例
2015/03/16 PHP
php实现在服务器上创建目录的方法
2015/03/16 PHP
PHP中把对象转换为关联数组代码分享
2015/04/09 PHP
用javascript自动显示最后更新时间
2007/03/15 Javascript
JS显示下拉列表框内全部元素的方法
2015/03/31 Javascript
JQuery日历插件My97DatePicker日期范围限制
2016/01/20 Javascript
JS实现的表头列头固定页面功能示例
2017/01/10 Javascript
js 原型对象和原型链理解
2017/02/09 Javascript
vue.js删除动态绑定的radio的指定项
2017/06/02 Javascript
bootstrap table实现点击翻页功能 可记录上下页选中的行
2017/09/28 Javascript
JS组件系列之Gojs组件 前端图形化插件之利器
2017/11/29 Javascript
Angular6新特性之Angular Material
2018/12/28 Javascript
使用VUE+iView+.Net Core上传图片的方法示例
2019/01/04 Javascript
JavaScript实现的九种排序算法
2019/03/04 Javascript
Vue实例的对象参数options的几个常用选项详解
2019/11/08 Javascript
angula中使用iframe点击后不执行变更检测的问题
2020/05/10 Javascript
python概率计算器实例分析
2015/03/25 Python
Linux下python3.7.0安装教程
2018/07/30 Python
python斐波那契数列的计算方法
2018/09/27 Python
python常用函数与用法示例
2019/07/02 Python
python返回数组的索引实例
2019/11/28 Python
什么是SQL Server的确定性函数和不确定性函数
2016/08/04 面试题
Java语言程序设计测试题判断题部分
2013/01/06 面试题
ruby如何进行集成操作?Ruby能进行多重继承吗?
2013/10/16 面试题
给老师的道歉信
2014/01/11 职场文书
班级道德讲堂实施方案
2014/02/24 职场文书
安全责任协议书
2014/04/21 职场文书
关于雷锋的演讲稿
2014/05/10 职场文书
理发店策划方案
2014/06/05 职场文书
四风剖析查摆对照检查材料思想汇报
2014/09/24 职场文书
销售员自我评价
2015/03/11 职场文书
JS 4个超级实用的小技巧 提升开发效率
2021/10/05 Javascript