Calling common sklearn classification models in Python


Posted in Python on October 16, 2019

This article shares concrete code for calling sklearn classification models in Python, for your reference. The details are as follows.

The script below makes simple calls to the 'NB', 'KNN', 'LR', 'RF', 'DT', 'SVM', 'SVMCV', and 'GBDT' models.
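The script reads a CSV whose columns are numeric features plus a `label` column. If you don't have the author's `trainCG.csv`, a minimal sketch like the following can write a file in that layout. It assumes scikit-learn 0.23 or later (for `as_frame=True`) and uses the bundled breast-cancer dataset as a stand-in; the function name `write_demo_csv` and the file name `trainCG_demo.csv` are hypothetical. That dataset's features are non-negative, which MultinomialNB requires, and its labels are 0/1, which matches the script's binary precision and recall calls.

```python
# Minimal sketch: write a CSV in the layout read_data() expects, i.e. numeric
# feature columns plus a 'label' column.
# Assumptions: scikit-learn >= 0.23 (for as_frame=True); the bundled
# breast-cancer dataset stands in for trainCG.csv; names are hypothetical.
from sklearn.datasets import load_breast_cancer


def write_demo_csv(path="trainCG_demo.csv", seed=0):
    frame = load_breast_cancer(as_frame=True).frame
    # The bundled frame calls its target column 'target'; read_data() wants 'label'.
    frame = frame.rename(columns={"target": "label"})
    # Shuffle rows, because read_data() splits 90/10 in file order.
    frame = frame.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    frame.to_csv(path, index=False)
    return path


if __name__ == "__main__":
    print(write_demo_csv())
```

Point `data_file` at the generated path to run the comparison end to end.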

# coding=gbk
 
import time 
from sklearn import metrics 
import pickle
import pandas as pd
 
 
# Multinomial Naive Bayes Classifier (features must be non-negative, e.g. counts)
def naive_bayes_classifier(train_x, train_y): 
  from sklearn.naive_bayes import MultinomialNB 
  model = MultinomialNB(alpha=0.01) 
  model.fit(train_x, train_y) 
  return model 
 
 
# KNN Classifier 
def knn_classifier(train_x, train_y): 
  from sklearn.neighbors import KNeighborsClassifier 
  model = KNeighborsClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# Logistic Regression Classifier 
def logistic_regression_classifier(train_x, train_y): 
  from sklearn.linear_model import LogisticRegression 
  model = LogisticRegression(penalty='l2') 
  model.fit(train_x, train_y) 
  return model 
 
 
# Random Forest Classifier 
def random_forest_classifier(train_x, train_y): 
  from sklearn.ensemble import RandomForestClassifier 
  model = RandomForestClassifier(n_estimators=8) 
  model.fit(train_x, train_y) 
  return model 
 
 
# Decision Tree Classifier 
def decision_tree_classifier(train_x, train_y): 
  from sklearn import tree 
  model = tree.DecisionTreeClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# GBDT(Gradient Boosting Decision Tree) Classifier 
def gradient_boosting_classifier(train_x, train_y): 
  from sklearn.ensemble import GradientBoostingClassifier 
  model = GradientBoostingClassifier(n_estimators=200) 
  model.fit(train_x, train_y) 
  return model 
 
 
# SVM Classifier 
def svm_classifier(train_x, train_y): 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
# SVM Classifier using cross validation 
def svm_cross_validation(train_x, train_y): 
  from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed in 0.20
  from sklearn.svm import SVC
  model = SVC(kernel='rbf', probability=True)
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}
  grid_search = GridSearchCV(model, param_grid, n_jobs=1, verbose=1)
  grid_search.fit(train_x, train_y) 
  best_parameters = grid_search.best_estimator_.get_params() 
  for para, val in list(best_parameters.items()): 
    print(para, val) 
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
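# Split rows in file order: first 90% for training, last 10% for testing (no shuffling).
# Expects a 'label' column; every other column is used as a feature.
def read_data(data_file):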
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
   
if __name__ == '__main__': 
  data_file = "H:\\Research\\data\\trainCG.csv" 
  thresh = 0.5  # not used below
  model_save_file = None  # set to a file path to pickle all trained models
  model_save = {} 
  
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM', 'SVMCV', 'GBDT']
  classifiers = {'NB': naive_bayes_classifier,
                 'KNN': knn_classifier,
                 'LR': logistic_regression_classifier,
                 'RF': random_forest_classifier,
                 'DT': decision_tree_classifier,
                 'SVM': svm_classifier,
                 'SVMCV': svm_cross_validation,
                 'GBDT': gradient_boosting_classifier
  }
   
  print('reading training and testing data...') 
  train_x, train_y, test_x, test_y = read_data(data_file) 
   
  for classifier in test_classifiers: 
    print('******************* %s ********************' % classifier) 
    start_time = time.time() 
    model = classifiers[classifier](train_x, train_y) 
    print('training took %fs!' % (time.time() - start_time)) 
    predict = model.predict(test_x) 
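    if model_save_file is not None:
      model_save[classifier] = model
    # precision_score/recall_score default to average='binary' with pos_label=1,
    # i.e. a two-class label column; pass e.g. average='macro' for multi-class data
    precision = metrics.precision_score(test_y, predict)
    recall = metrics.recall_score(test_y, predict)
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))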
    accuracy = metrics.accuracy_score(test_y, predict) 
    print('accuracy: %.2f%%' % (100 * accuracy))  
 
  if model_save_file is not None:
    with open(model_save_file, 'wb') as f:
      pickle.dump(model_save, f)
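When `model_save_file` is set, the script pickles a dict that maps each classifier name to its fitted model, but it never shows reading that file back. A minimal sketch of the round trip follows; the path `models.pkl` is hypothetical, two models fitted on scikit-learn's bundled breast-cancer data stand in for the script's output, and `save_models` / `load_models` are illustrative helper names, not part of the original.

```python
# Minimal sketch: save and reload a dict of fitted models with pickle,
# mirroring the {name: model} dict the script writes to model_save_file.
# Assumptions: "models.pkl" is a hypothetical path, and two models fitted on
# scikit-learn's bundled breast-cancer data stand in for the script's output.
import pickle

from sklearn.datasets import load_breast_cancer
from sklearn.naive_bayes import MultinomialNB
from sklearn.tree import DecisionTreeClassifier


def save_models(models, path):
    # Write the whole {name: fitted model} dict to one pickle file.
    with open(path, "wb") as f:
        pickle.dump(models, f)


def load_models(path):
    # Only unpickle files you trust: loading a pickle can execute arbitrary code.
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    x, y = load_breast_cancer(return_X_y=True)
    models = {
        "NB": MultinomialNB(alpha=0.01).fit(x, y),
        "DT": DecisionTreeClassifier(random_state=0).fit(x, y),
    }
    save_models(models, "models.pkl")
    for name, model in load_models("models.pkl").items():
        print(name, model.predict(x[:5]))
```

Pickled scikit-learn models are generally only guaranteed to load correctly under the same scikit-learn version that saved them.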

That's all for this article. I hope it helps with your learning, and thank you for supporting 三水点靠木.
