使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python多线程爬虫爬取电影天堂资源
Sep 23 Python
基于numpy中数组元素的切片复制方法
Nov 15 Python
浅谈python脚本设置运行参数的方法
Dec 03 Python
基于Python2、Python3中reload()的不同用法介绍
Aug 12 Python
python+selenium select下拉选择框定位处理方法
Aug 24 Python
Pygame的程序开始示例代码
May 07 Python
使用Numpy对特征中的异常值进行替换及条件替换方式
Jun 08 Python
python实现AdaBoost算法的示例
Oct 03 Python
python time()的实例用法
Nov 03 Python
python statsmodel的使用
Dec 21 Python
Python基础之元类详解
Apr 29 Python
总结Python连接CS2000的详细步骤
Jun 23 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
使用字符串函数输出整数化的PHP版本号
2006/10/09 PHP
用Php编写注册后Email激活验证的实例代码
2013/03/11 PHP
php计算多维数组中所有值总和的方法
2015/06/24 PHP
php的mail函数发送UTF-8编码中文邮件时标题乱码的解决办法
2015/10/20 PHP
在WordPress的文章编辑器中设置默认内容的方法
2015/12/29 PHP
thinkPHP5.0框架整体架构总览【应用,模块,MVC,驱动,行为,命名空间等】
2017/03/25 PHP
老生常谈PHP面向对象之解释器模式
2017/05/17 PHP
解决Laravel无法使用COOKIE和SESSION的问题
2019/10/16 PHP
jQuery中filter(),not(),split()使用方法
2010/07/06 Javascript
基于jquery的时间段实现代码
2012/08/02 Javascript
通过隐藏iframe实现文件下载的js方法介绍
2014/02/26 Javascript
JavaScript中的toLocaleDateString()方法使用简介
2015/06/12 Javascript
json传值以及ajax接收详解
2016/05/24 Javascript
JavaScript 中有关数组对象的方法(详解)
2016/08/15 Javascript
AngularJS的ng Http Request与response格式转换方法
2016/11/07 Javascript
swiper插件自定义切换箭头按钮
2017/12/28 Javascript
vue下拉菜单组件(含搜索)的实现代码
2018/11/25 Javascript
vue-router之解决addRoutes使用遇到的坑
2020/07/19 Javascript
js实现石头剪刀布游戏
2020/10/11 Javascript
[01:29:31]VP VS VG Supermajor小组赛胜者组第二轮 BO3第一场 6.2
2018/06/03 DOTA
Python中super()函数简介及用法分享
2016/07/11 Python
深入理解Python中变量赋值的问题
2017/01/12 Python
Python cookbook(数据结构与算法)根据字段将记录分组操作示例
2018/03/19 Python
python逐行读写txt文件的实例讲解
2018/04/03 Python
基于python绘制科赫雪花
2018/06/22 Python
基于PyQt4和PySide实现输入对话框效果
2019/02/27 Python
Python操作列表常用方法实例小结【创建、遍历、统计、切片等】
2019/10/25 Python
python利用dlib获取人脸的68个landmark
2019/11/27 Python
Python运行DLL文件的方法
2020/01/17 Python
Python如何读写CSV文件
2020/08/13 Python
解决python打开https出现certificate verify failed的问题
2020/09/03 Python
茵宝(Umbro)英国官方商店:英国足球服装生产商
2016/12/29 全球购物
文化产业实施方案
2014/06/07 职场文书
毕业生党员个人总结
2015/02/14 职场文书
诚实守信主题班会
2015/08/13 职场文书
终止合同协议书范本
2016/03/22 职场文书