使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python数据分析之双色球基于线性回归算法预测下期中奖结果示例
Feb 08 Python
用Eclipse写python程序
Feb 10 Python
python3.6.3转化为win-exe文件发布的方法
Oct 31 Python
selenium处理元素定位点击无效问题
Jun 12 Python
python向字符串中添加元素的实例方法
Jun 28 Python
对python3中的RE(正则表达式)-详细总结
Jul 23 Python
Django中使用session保持用户登陆连接的例子
Aug 06 Python
Python学习笔记之While循环用法分析
Aug 14 Python
Python使用selenium + headless chrome获取网页内容的方法示例
Oct 16 Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 Python
Python实现简单的猜单词小游戏
Oct 28 Python
python神经网络 使用Keras构建RNN训练
May 04 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
基于数据库的在线人数,日访问量等统计
2006/10/09 PHP
通用PHP动态生成静态HTML网页的代码
2010/03/04 PHP
PHP验证码类代码( 最新修改,完全定制化! )
2010/12/02 PHP
PHP正则表达式替换站点关键字链接后空白的解决方法
2014/09/16 PHP
asp.net和php的区别点总结
2019/10/10 PHP
利用javascript/jquery对上传文件格式过滤的方法
2009/07/25 Javascript
基于JQuery框架的AJAX实例代码
2009/11/03 Javascript
javascript 基础篇3 类,回调函数,内置对象,事件处理
2012/03/14 Javascript
JS中获取数据库中的值的方法
2013/07/14 Javascript
jquery触发a标签跳转事件示例代码
2013/07/21 Javascript
浅谈jQuery中 wrap() wrapAll() 与 wrapInner()的差异
2014/11/12 Javascript
禁止按回车键提交表单的方法
2015/06/11 Javascript
jQuery实现带延迟的二级tab切换下拉列表效果
2015/09/01 Javascript
关于Vue.js一些问题和思考学习笔记(1)
2016/12/02 Javascript
前端页面文件拖拽上传模块js代码示例
2017/05/19 Javascript
详解vue mint-ui源码解析之loadmore组件
2017/10/11 Javascript
jQuery HTML设置内容和属性操作实例分析
2020/05/20 jQuery
vue 使用localstorage实现面包屑的操作
2020/11/16 Javascript
[53:52]OG vs EG 2018国际邀请赛淘汰赛BO3 第二场 8.23
2018/08/24 DOTA
Python中将dataframe转换为字典的实例
2018/04/13 Python
django利用request id便于定位及给日志加上request_id
2018/08/26 Python
Python自定义函数计算给定日期是该年第几天的方法示例
2019/05/30 Python
python解析xml简单示例
2019/06/21 Python
Python画图高斯分布的示例
2019/07/10 Python
细数nn.BCELoss与nn.CrossEntropyLoss的区别
2020/02/29 Python
Python yield生成器和return对比代码实例
2020/04/20 Python
Python安装第三方库攻略(pip和Anaconda)
2020/10/15 Python
京东奢侈品:全球奢侈品牌
2018/03/17 全球购物
倩碧澳大利亚官网:Clinique澳大利亚
2019/07/22 全球购物
Everlast官网:拳击、综合格斗和健身相关的体育用品
2020/08/03 全球购物
通信工程专业个人找工作求职信范文
2013/09/21 职场文书
仓库管理制度
2014/01/21 职场文书
创业者如何撰写出一份打动投资人的商业计划书?
2019/07/02 职场文书
AJAX学习笔记
2021/05/18 Javascript
Python制作表白爱心合集
2022/01/22 Python
Python可视化神器pyecharts之绘制地理图表练习
2022/07/07 Python