使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中循环语句while用法实例
May 16 Python
关于Python面向对象编程的知识点总结
Feb 14 Python
详解Golang 与python中的字符串反转
Jul 21 Python
Python实现的中国剩余定理算法示例
Aug 05 Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 Python
浅谈python 中类属性共享的问题
Jul 02 Python
python 使用while写猜年龄小游戏过程解析
Oct 07 Python
Python socket模块ftp传输文件过程解析
Nov 05 Python
pycharm 2018 激活码及破解补丁激活方式
Sep 21 Python
Python基于callable函数检测对象是否可被调用
Oct 16 Python
Python机器学习应用之基于线性判别模型的分类篇详解
Jan 18 Python
用Python可视化新冠疫情数据
Jan 18 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
php实现httpRequest的方法
2015/03/13 PHP
php中smarty模板条件判断用法实例
2015/06/11 PHP
调试WordPress中定时任务的相关PHP脚本示例
2015/12/10 PHP
PHP与jquery实时显示网站在线人数实例详解
2016/12/02 PHP
PHP基于XMLWriter操作xml的方法分析
2017/07/17 PHP
laravel 框架配置404等异常页面
2019/01/07 PHP
PHP单例模式数据库连接类与页面静态化实现方法
2019/03/20 PHP
一个JS小玩意 几个属性相加不能超过一个特定值.
2009/09/29 Javascript
js 多种变量定义(对象直接量,数组直接量和函数直接量)
2010/05/24 Javascript
javascript中alert()与console.log()的区别
2015/08/26 Javascript
js数组的五种迭代方法及两种归并方法(推荐)
2016/06/14 Javascript
Bootstrap基本样式学习笔记之图片(6)
2016/12/07 Javascript
详解vue-validator(vue验证器)
2017/01/16 Javascript
微信小程序之网络请求简单封装实例详解
2017/06/28 Javascript
微信小程序实现左右列表联动
2020/05/19 Javascript
Js通过AES加密后PHP用Openssl解密的方法
2019/07/12 Javascript
jQuery表单选择器用法详解
2019/08/22 jQuery
JS获取当前时间戳方法解析
2020/08/29 Javascript
Python常用的日期时间处理方法示例
2015/02/08 Python
解决pyqt中ui编译成窗体.py中文乱码的问题
2016/12/23 Python
Python使用pip安装pySerial串口通讯模块
2018/04/20 Python
Python实现简易过滤删除数字的方法小结
2019/01/09 Python
python+selenium实现简历自动刷新的示例代码
2019/05/20 Python
浅析Django中关于session的使用
2019/12/30 Python
Python查找不限层级Json数据中某个key或者value的路径方式
2020/02/27 Python
德国体育用品网上商店:SC24.com
2016/08/01 全球购物
GUESS西班牙官方网上商城:美国服饰品牌
2017/03/15 全球购物
西部世纪面试题
2014/12/05 面试题
vue 中 get / delete 传递数组参数方法
2021/03/23 Vue.js
在求职信中如何凸显个人优势
2013/10/30 职场文书
竞聘副主任科员演讲稿
2014/01/11 职场文书
2014大四本科生自我鉴定总结
2014/10/04 职场文书
华清池导游词
2015/02/02 职场文书
Vue.js 带下拉选项的输入框(Textbox with Dropdown)组件
2021/04/17 Vue.js
Python利用folium实现地图可视化
2021/05/23 Python
浅谈MySQL中的六种日志
2022/03/23 MySQL