Calling common sklearn classification models in Python


Posted in Python on October 16, 2019

This article shares working code for calling sklearn classification models from Python, for your reference. The details are as follows.

The code provides simple calls to the 'NB', 'KNN', 'LR', 'RF', 'DT', 'SVM', 'SVMCV', and 'GBDT' models.
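To show the call pattern in isolation, here is a minimal sketch that maps short names to model factories, then fits and scores each one. It runs on synthetic data from `make_classification` rather than the article's CSV file. The `factories` and `results` names, the choice of three models, and `max_iter=1000` are assumptions made for illustration only.

```python
from sklearn import metrics
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

# Synthetic binary data stands in for the article's CSV file (assumption).
X, y = make_classification(n_samples=300, n_features=10, random_state=0)
train_x, test_x, train_y, test_y = train_test_split(
    X, y, test_size=0.1, random_state=0, stratify=y)

# Map a short name to a factory that builds an untrained model.
factories = {
    'KNN': lambda: KNeighborsClassifier(),
    'LR': lambda: LogisticRegression(penalty='l2', max_iter=1000),
    'DT': lambda: DecisionTreeClassifier(random_state=0),
}

results = {}
for name, make_model in factories.items():
    # Every sklearn classifier exposes the same fit/predict interface.
    model = make_model().fit(train_x, train_y)
    predict = model.predict(test_x)
    results[name] = metrics.accuracy_score(test_y, predict)
    print('%s accuracy: %.2f%%' % (name, 100 * results[name]))
```

Because the models share one interface, adding another classifier only requires one more entry in the dictionary.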

# coding=gbk
 
import time 
from sklearn import metrics 
import pickle
import pandas as pd
 
 
# Multinomial Naive Bayes Classifier 
def naive_bayes_classifier(train_x, train_y): 
  from sklearn.naive_bayes import MultinomialNB 
  model = MultinomialNB(alpha=0.01) 
  model.fit(train_x, train_y) 
  return model 
 
 
# KNN Classifier 
def knn_classifier(train_x, train_y): 
  from sklearn.neighbors import KNeighborsClassifier 
  model = KNeighborsClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# Logistic Regression Classifier 
def logistic_regression_classifier(train_x, train_y): 
  from sklearn.linear_model import LogisticRegression 
  model = LogisticRegression(penalty='l2') 
  model.fit(train_x, train_y) 
  return model 
 
 
# Random Forest Classifier 
def random_forest_classifier(train_x, train_y): 
  from sklearn.ensemble import RandomForestClassifier 
  model = RandomForestClassifier(n_estimators=8) 
  model.fit(train_x, train_y) 
  return model 
 
 
# Decision Tree Classifier 
def decision_tree_classifier(train_x, train_y): 
  from sklearn import tree 
  model = tree.DecisionTreeClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# GBDT(Gradient Boosting Decision Tree) Classifier 
def gradient_boosting_classifier(train_x, train_y): 
  from sklearn.ensemble import GradientBoostingClassifier 
  model = GradientBoostingClassifier(n_estimators=200) 
  model.fit(train_x, train_y) 
  return model 
 
 
# SVM Classifier 
def svm_classifier(train_x, train_y): 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
# SVM Classifier using cross validation 
def svm_cross_validation(train_x, train_y): 
  from sklearn.model_selection import GridSearchCV
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]} 
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1) 
  grid_search.fit(train_x, train_y) 
  best_parameters = grid_search.best_estimator_.get_params() 
  for para, val in list(best_parameters.items()): 
    print(para, val) 
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
def read_data(data_file): 
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
   
if __name__ == '__main__': 
  data_file = "H:\\Research\\data\\trainCG.csv" 
  thresh = 0.5 
  model_save_file = None 
  model_save = {} 
  
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT'] 
  classifiers = {'NB':naive_bayes_classifier,  
         'KNN':knn_classifier, 
          'LR':logistic_regression_classifier, 
          'RF':random_forest_classifier, 
          'DT':decision_tree_classifier, 
         'SVM':svm_classifier, 
        'SVMCV':svm_cross_validation, 
         'GBDT':gradient_boosting_classifier 
  } 
   
  print('reading training and testing data...') 
  train_x, train_y, test_x, test_y = read_data(data_file) 
   
  for classifier in test_classifiers: 
    print('******************* %s ********************' % classifier) 
    start_time = time.time() 
    model = classifiers[classifier](train_x, train_y) 
    print('training took %fs!' % (time.time() - start_time)) 
    predict = model.predict(test_x) 
    if model_save_file is not None:
      model_save[classifier] = model 
    precision = metrics.precision_score(test_y, predict) 
    recall = metrics.recall_score(test_y, predict) 
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall)) 
    accuracy = metrics.accuracy_score(test_y, predict) 
    print('accuracy: %.2f%%' % (100 * accuracy))  
 
  if model_save_file is not None:
    # Use a context manager so the file is closed after writing.
    with open(model_save_file, 'wb') as f:
      pickle.dump(model_save, f)
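The script only writes the trained models to disk, so a short sketch of the reverse step may help. The example below pickles a dictionary of models and loads it back, under the assumption that the same scikit-learn version is used for saving and loading. The temporary path, the `restored` and `same` names, and the synthetic data are illustrative and not part of the original script.

```python
import os
import pickle
import tempfile

from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Small synthetic dataset used only for this illustration (assumption).
X, y = make_classification(n_samples=100, n_features=5, random_state=0)

# Same layout as the script's model_save dict: short name -> fitted model.
model_save = {'DT': DecisionTreeClassifier(random_state=0).fit(X, y)}

# Hypothetical save location; the original script leaves model_save_file as None.
model_save_file = os.path.join(tempfile.mkdtemp(), 'models.pkl')
with open(model_save_file, 'wb') as f:
    pickle.dump(model_save, f)

# Load the dictionary back and check that predictions are unchanged.
with open(model_save_file, 'rb') as f:
    restored = pickle.load(f)

same = bool((restored['DT'].predict(X) == model_save['DT'].predict(X)).all())
print('restored predictions match:', same)
```

Only unpickle files from a trusted source, since loading a pickle can execute arbitrary code.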

That is all for this article. I hope it helps with your studies, and I hope you will continue to support 三水点靠木.
