python sklearn常用分类算法模型的调用


Posted in Python onOctober 16, 2019

本文实例为大家分享了python sklearn分类算法模型调用的具体代码,供大家参考,具体内容如下

实现对'NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT'模型的简单调用。

# coding=gbk
 
import time 
from sklearn import metrics 
import pickle as pickle 
import pandas as pd
 
 
# Multinomial Naive Bayes Classifier 
def naive_bayes_classifier(train_x, train_y): 
  from sklearn.naive_bayes import MultinomialNB 
  model = MultinomialNB(alpha=0.01) 
  model.fit(train_x, train_y) 
  return model 
 
 
# KNN Classifier 
def knn_classifier(train_x, train_y): 
  from sklearn.neighbors import KNeighborsClassifier 
  model = KNeighborsClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# Logistic Regression Classifier 
def logistic_regression_classifier(train_x, train_y): 
  from sklearn.linear_model import LogisticRegression 
  model = LogisticRegression(penalty='l2') 
  model.fit(train_x, train_y) 
  return model 
 
 
# Random Forest Classifier 
def random_forest_classifier(train_x, train_y): 
  from sklearn.ensemble import RandomForestClassifier 
  model = RandomForestClassifier(n_estimators=8) 
  model.fit(train_x, train_y) 
  return model 
 
 
# Decision Tree Classifier 
def decision_tree_classifier(train_x, train_y): 
  from sklearn import tree 
  model = tree.DecisionTreeClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# GBDT(Gradient Boosting Decision Tree) Classifier 
def gradient_boosting_classifier(train_x, train_y): 
  from sklearn.ensemble import GradientBoostingClassifier 
  model = GradientBoostingClassifier(n_estimators=200) 
  model.fit(train_x, train_y) 
  return model 
 
 
# SVM Classifier 
def svm_classifier(train_x, train_y): 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
# SVM Classifier using cross validation 
def svm_cross_validation(train_x, train_y): 
  from sklearn.grid_search import GridSearchCV 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]} 
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1) 
  grid_search.fit(train_x, train_y) 
  best_parameters = grid_search.best_estimator_.get_params() 
  for para, val in list(best_parameters.items()): 
    print(para, val) 
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
def read_data(data_file): 
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
   
if __name__ == '__main__': 
  data_file = "H:\\Research\\data\\trainCG.csv" 
  thresh = 0.5 
  model_save_file = None 
  model_save = {} 
  
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT'] 
  classifiers = {'NB':naive_bayes_classifier,  
         'KNN':knn_classifier, 
          'LR':logistic_regression_classifier, 
          'RF':random_forest_classifier, 
          'DT':decision_tree_classifier, 
         'SVM':svm_classifier, 
        'SVMCV':svm_cross_validation, 
         'GBDT':gradient_boosting_classifier 
  } 
   
  print('reading training and testing data...') 
  train_x, train_y, test_x, test_y = read_data(data_file) 
   
  for classifier in test_classifiers: 
    print('******************* %s ********************' % classifier) 
    start_time = time.time() 
    model = classifiers[classifier](train_x, train_y) 
    print('training took %fs!' % (time.time() - start_time)) 
    predict = model.predict(test_x) 
    if model_save_file != None: 
      model_save[classifier] = model 
    precision = metrics.precision_score(test_y, predict) 
    recall = metrics.recall_score(test_y, predict) 
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall)) 
    accuracy = metrics.accuracy_score(test_y, predict) 
    print('accuracy: %.2f%%' % (100 * accuracy))  
 
  if model_save_file != None: 
    pickle.dump(model_save, open(model_save_file, 'wb'))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python和ruby,我选谁?
Sep 13 Python
Python简单实现查找一个字符串中最长不重复子串的方法
Mar 26 Python
对python Tkinter Text的用法详解
Oct 11 Python
浅析Python 读取图像文件的性能对比
Mar 07 Python
python+webdriver自动化环境搭建步骤详解
Jun 03 Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 Python
应用OpenCV和Python进行SIFT算法的实现详解
Aug 21 Python
python的mysql数据库建立表与插入数据操作示例
Sep 30 Python
python装饰器代替set get方法实例
Dec 19 Python
Python基于stuck实现scoket文件传输
Apr 02 Python
为什么说python适合写爬虫
Jun 11 Python
Python代码,能玩30多款童年游戏!这些有几个是你玩过的
Apr 27 Python
Python使用selenium + headless chrome获取网页内容的方法示例
Oct 16 #Python
使用python实现kNN分类算法
Oct 16 #Python
Python生成验证码、计算具体日期是一年中的第几天实例代码详解
Oct 16 #Python
python可视化实现KNN算法
Oct 16 #Python
python实现KNN分类算法
Oct 16 #Python
python子线程退出及线程退出控制的代码
Oct 16 #Python
python Pillow图像处理方法汇总
Oct 16 #Python
You might like
php 之 没有mysql支持时的替代方案
2006/10/09 PHP
php下图片文字混合水印与缩略图实现代码
2009/12/11 PHP
php imagecreatetruecolor 创建高清和透明图片代码小结
2010/05/15 PHP
php ci框架中加载css和js文件失败的解决方法
2014/03/03 PHP
根据key删除数组中指定的元素实现方法
2017/03/02 PHP
SWFObject 2.1以上版本语法介绍
2010/07/10 Javascript
超越Jquery_01_isPlainObject分析与重构
2010/10/20 Javascript
js查找某元素中的所有图片地址的方法
2014/01/16 Javascript
css与javascript跨浏览器兼容性总结
2014/09/15 Javascript
Internet Explorer 11 浏览器介绍:别叫我IE
2014/09/28 Javascript
js实现select跳转菜单新窗口效果代码分享(超简单)
2015/08/21 Javascript
JavaScript ES6中CLASS的使用详解
2016/11/22 Javascript
微信小程序 setData使用方法及常用错误解决办法
2017/05/11 Javascript
详解前端路由实现与react-router使用姿势
2017/08/07 Javascript
JavaScript调用模式与this关键字绑定的关系
2018/04/21 Javascript
node前端开发模板引擎Jade的入门
2018/05/11 Javascript
Vue框架里使用Swiper的方法示例
2018/09/20 Javascript
关于node-bindings无法在Electron中使用的解决办法
2018/12/18 Javascript
微信小程序中的上拉、下拉菜单功能
2020/03/13 Javascript
原生JS封装拖动验证滑块的实现代码示例
2020/06/01 Javascript
Vue $emit()不能触发父组件方法的原因及解决
2020/07/28 Javascript
Python2.7 实现引入自己写的类方法
2018/04/29 Python
使用Python进行QQ批量登录的实例代码
2018/06/11 Python
Python检查和同步本地时间(北京时间)的实现方法
2018/12/03 Python
关于Python 常用获取元素 Driver 总结
2019/11/24 Python
详解Python在使用JSON时需要注意的编码问题
2019/12/06 Python
HTML5 CSS3给网站设计带来出色效果
2009/07/16 HTML / CSS
运动会通讯稿50字
2014/01/30 职场文书
医学专业应届生的自我评价
2014/02/28 职场文书
经典团队口号
2014/06/06 职场文书
销售活动策划方案
2014/08/26 职场文书
销售顾问工作计划书
2014/09/15 职场文书
后勤个人工作总结
2015/02/28 职场文书
2016圣诞节贺卡寄语
2015/12/07 职场文书
手把手教你实现PyTorch的MNIST数据集
2021/06/28 Python
避坑之 JavaScript 中的toFixed()和正则表达式
2022/04/19 Javascript