python sklearn常用分类算法模型的调用


Posted in Python onOctober 16, 2019

本文实例为大家分享了python sklearn分类算法模型调用的具体代码,供大家参考,具体内容如下

实现对'NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT'模型的简单调用。

# coding=gbk
 
import time 
from sklearn import metrics 
import pickle as pickle 
import pandas as pd
 
 
# Multinomial Naive Bayes Classifier 
def naive_bayes_classifier(train_x, train_y): 
  from sklearn.naive_bayes import MultinomialNB 
  model = MultinomialNB(alpha=0.01) 
  model.fit(train_x, train_y) 
  return model 
 
 
# KNN Classifier 
def knn_classifier(train_x, train_y): 
  from sklearn.neighbors import KNeighborsClassifier 
  model = KNeighborsClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# Logistic Regression Classifier 
def logistic_regression_classifier(train_x, train_y): 
  from sklearn.linear_model import LogisticRegression 
  model = LogisticRegression(penalty='l2') 
  model.fit(train_x, train_y) 
  return model 
 
 
# Random Forest Classifier 
def random_forest_classifier(train_x, train_y): 
  from sklearn.ensemble import RandomForestClassifier 
  model = RandomForestClassifier(n_estimators=8) 
  model.fit(train_x, train_y) 
  return model 
 
 
# Decision Tree Classifier 
def decision_tree_classifier(train_x, train_y): 
  from sklearn import tree 
  model = tree.DecisionTreeClassifier() 
  model.fit(train_x, train_y) 
  return model 
 
 
# GBDT(Gradient Boosting Decision Tree) Classifier 
def gradient_boosting_classifier(train_x, train_y): 
  from sklearn.ensemble import GradientBoostingClassifier 
  model = GradientBoostingClassifier(n_estimators=200) 
  model.fit(train_x, train_y) 
  return model 
 
 
# SVM Classifier 
def svm_classifier(train_x, train_y): 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
# SVM Classifier using cross validation 
def svm_cross_validation(train_x, train_y): 
  from sklearn.grid_search import GridSearchCV 
  from sklearn.svm import SVC 
  model = SVC(kernel='rbf', probability=True) 
  param_grid = {'C': [1e-3, 1e-2, 1e-1, 1, 10, 100, 1000], 'gamma': [0.001, 0.0001]} 
  grid_search = GridSearchCV(model, param_grid, n_jobs = 1, verbose=1) 
  grid_search.fit(train_x, train_y) 
  best_parameters = grid_search.best_estimator_.get_params() 
  for para, val in list(best_parameters.items()): 
    print(para, val) 
  model = SVC(kernel='rbf', C=best_parameters['C'], gamma=best_parameters['gamma'], probability=True) 
  model.fit(train_x, train_y) 
  return model 
 
def read_data(data_file): 
  data = pd.read_csv(data_file)
  train = data[:int(len(data)*0.9)]
  test = data[int(len(data)*0.9):]
  train_y = train.label
  train_x = train.drop('label', axis=1)
  test_y = test.label
  test_x = test.drop('label', axis=1)
  return train_x, train_y, test_x, test_y
   
if __name__ == '__main__': 
  data_file = "H:\\Research\\data\\trainCG.csv" 
  thresh = 0.5 
  model_save_file = None 
  model_save = {} 
  
  test_classifiers = ['NB', 'KNN', 'LR', 'RF', 'DT', 'SVM','SVMCV', 'GBDT'] 
  classifiers = {'NB':naive_bayes_classifier,  
         'KNN':knn_classifier, 
          'LR':logistic_regression_classifier, 
          'RF':random_forest_classifier, 
          'DT':decision_tree_classifier, 
         'SVM':svm_classifier, 
        'SVMCV':svm_cross_validation, 
         'GBDT':gradient_boosting_classifier 
  } 
   
  print('reading training and testing data...') 
  train_x, train_y, test_x, test_y = read_data(data_file) 
   
  for classifier in test_classifiers: 
    print('******************* %s ********************' % classifier) 
    start_time = time.time() 
    model = classifiers[classifier](train_x, train_y) 
    print('training took %fs!' % (time.time() - start_time)) 
    predict = model.predict(test_x) 
    if model_save_file != None: 
      model_save[classifier] = model 
    precision = metrics.precision_score(test_y, predict) 
    recall = metrics.recall_score(test_y, predict) 
    print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall)) 
    accuracy = metrics.accuracy_score(test_y, predict) 
    print('accuracy: %.2f%%' % (100 * accuracy))  
 
  if model_save_file != None: 
    pickle.dump(model_save, open(model_save_file, 'wb'))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python下的Mysql模块MySQLdb安装详解
Apr 09 Python
pycharm 使用心得(二)设置字体大小
Jun 05 Python
python基础知识小结之集合
Nov 25 Python
Python多线程经典问题之乘客做公交车算法实例
Mar 22 Python
CentOS 6.5下安装Python 3.5.2(与Python2并存)
Jun 05 Python
Python3 中把txt数据文件读入到矩阵中的方法
Apr 27 Python
python中不能连接超时的问题及解决方法
Jun 10 Python
使用pandas批量处理矢量化字符串的实例讲解
Jul 10 Python
[原创]Python入门教程3. 列表基本操作【定义、运算、常用函数】
Oct 30 Python
Python提取转移文件夹内所有.jpg文件并查看每一帧的方法
Jun 27 Python
Pytorch 实现权重初始化
Dec 31 Python
学python最电脑配置有要求么
Jul 05 Python
Python使用selenium + headless chrome获取网页内容的方法示例
Oct 16 #Python
使用python实现kNN分类算法
Oct 16 #Python
Python生成验证码、计算具体日期是一年中的第几天实例代码详解
Oct 16 #Python
python可视化实现KNN算法
Oct 16 #Python
python实现KNN分类算法
Oct 16 #Python
python子线程退出及线程退出控制的代码
Oct 16 #Python
python Pillow图像处理方法汇总
Oct 16 #Python
You might like
PHP中使用数组实现堆栈数据结构的代码
2012/02/05 PHP
有关phpmailer的详细介绍及使用方法
2013/01/28 PHP
php找出指定范围内回文数且平方根也是回文数的方法
2015/03/23 PHP
PHP基于phpqrcode生成带LOGO图像的二维码实例
2015/07/10 PHP
解析WordPress中控制用户登陆和判断用户登陆的PHP函数
2016/03/01 PHP
TP5框架请求响应参数实例分析
2019/10/17 PHP
用javascript连接access数据库的方法
2006/11/17 Javascript
ExtJS 工具栏 分页事件参数
2010/03/05 Javascript
jQuery动态添加的元素绑定事件处理函数代码
2011/08/02 Javascript
document.documentElement的一些使用技巧
2013/04/18 Javascript
javascript每日必学之基础入门
2016/02/16 Javascript
jQuery插件实现文字无缝向上滚动效果代码
2016/02/25 Javascript
Web Uploader文件上传插件使用详解
2016/05/10 Javascript
JQuery控制图片由中心点逐渐放大效果
2016/06/26 Javascript
深入对Vue.js $watch方法的理解
2017/03/20 Javascript
Bootstrap 表单验证formValidation 实现远程验证功能
2017/05/17 Javascript
JavaScript实现二叉树定义、遍历及查找的方法详解
2017/12/20 Javascript
vue组件数据传递、父子组件数据获取,slot,router路由功能示例
2019/03/19 Javascript
如何自动化部署项目?折腾服务器之旅~
2019/04/16 Javascript
[52:41]OG vs IG 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/20 DOTA
[01:07:46]完美世界DOTA2联赛循环赛 Magma vs IO BO2第二场 11.01
2020/11/02 DOTA
使用python编写脚本获取手机当前应用apk的信息
2014/07/21 Python
python技能之数据导出excel的实例代码
2017/08/11 Python
Python之列表的插入&替换修改方法
2018/06/28 Python
基于python实现从尾到头打印链表
2019/11/02 Python
基于python 凸包问题的解决
2020/04/16 Python
Python文件夹批处理操作代码实例
2020/07/21 Python
Pytest单元测试框架如何实现参数化
2020/09/05 Python
美国在线乐器和设备商店:Musician’s Friend
2018/07/06 全球购物
英国二手iPhone、音乐、电影和游戏商店:musicMagpie
2018/10/26 全球购物
365 Tickets英国:全球景点门票
2019/07/06 全球购物
社区学雷锋活动策划方案
2014/01/30 职场文书
运动会拉拉队口号
2014/06/09 职场文书
检察院院长群众路线教育实践活动个人整改措施
2014/10/04 职场文书
教师作风整顿个人剖析材料
2014/10/10 职场文书
MySQL脏读,幻读和不可重复读
2022/05/11 MySQL