使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 创建子进程模块subprocess详解
Apr 08 Python
详解Python3中yield生成器的用法
Aug 20 Python
Python 查看文件的读写权限方法
Jan 23 Python
python 给DataFrame增加index行名和columns列名的实现方法
Jun 08 Python
python数据批量写入ScrolledText的优化方法
Oct 11 Python
python并发和异步编程实例
Nov 15 Python
Django保护敏感信息的方法示例
May 09 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 Python
python 求定积分和不定积分示例
Nov 20 Python
Python StringIO如何在内存中读写str
Jan 07 Python
C3 线性化算法与 MRO之Python中的多继承
Oct 05 Python
Python基本知识点总结
Apr 07 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
写一个用户在线显示的程序
2006/10/09 PHP
php 中英文语言转换类
2011/09/07 PHP
PHP的PDO操作简单示例
2016/03/30 PHP
PHP延迟静态绑定的深入讲解
2018/04/02 PHP
CI框架(CodeIgniter)实现的数据库增删改查操作总结
2018/05/23 PHP
yii框架结合charjs统计上一年与当前年数据的方法示例
2020/04/04 PHP
PHP常用字符串函数用法实例总结
2020/06/04 PHP
jQuery EasyUI中对表格进行编辑的实现代码
2010/06/10 Javascript
JQuery UI DatePicker中z-index默认为1的解决办法
2010/09/28 Javascript
JavaScript 盒模型 尺寸深入理解
2012/12/31 Javascript
解决JS浮点数运算出现Bug的方法
2013/03/12 Javascript
JavaScript通过RegExp实现客户端验证处理程序
2013/05/07 Javascript
jquery根据锚点offset值实现动画切换
2014/09/11 Javascript
jQuery中offsetParent()方法用法实例
2015/01/19 Javascript
jQuery结合HTML5制作的爱心树表白动画
2015/02/01 Javascript
JS实现文件动态顺序载入的方法
2015/03/07 Javascript
jQuery实现的fixedMenu下拉菜单效果代码
2015/08/24 Javascript
最简单纯JavaScript实现Tab标签页切换的方式(推荐)
2016/07/25 Javascript
jQuery自制提示框tooltip改进版
2016/08/01 Javascript
原生js实现查询天气小应用
2016/12/09 Javascript
JS实现浏览器打印、打印预览示例
2017/02/28 Javascript
addEventListener()与removeEventListener()解析
2017/04/20 Javascript
详解angularJs模块ui-router之状态嵌套和视图嵌套
2017/04/28 Javascript
vue.js实现的经典计算器/科学计算器功能示例
2018/07/11 Javascript
使用 UniApp 实现小程序的微信登录功能
2020/06/09 Javascript
[49:56]VG vs Optic 2018国际邀请赛小组赛BO2 第一场 8.19
2018/08/21 DOTA
Python打印输出数组中全部元素
2018/03/13 Python
python Dijkstra算法实现最短路径问题的方法
2019/09/19 Python
django框架中间件原理与用法详解
2019/12/10 Python
pytorch 彩色图像转灰度图像实例
2020/01/13 Python
Python多线程获取返回值代码实例
2020/02/17 Python
Python下载的11种姿势(小结)
2020/11/18 Python
详解CSS3阴影 box-shadow的使用和技巧总结
2016/12/03 HTML / CSS
娇韵诗Clarins意大利官方网站:法国天然护肤品牌
2020/03/11 全球购物
2015年财务人员工作总结
2015/04/10 职场文书
让世界充满爱观后感
2015/06/10 职场文书