使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之开始真正编程
Sep 12 Python
Python中使用scapy模拟数据包实现arp攻击、dns放大攻击例子
Oct 23 Python
浅谈python迭代器
Nov 08 Python
好的Python培训机构应该具备哪些条件
May 23 Python
python selenium爬取斗鱼所有直播房间信息过程详解
Aug 09 Python
python面向对象之类属性和类方法案例分析
Dec 30 Python
Python爬取YY评级分数并保存数据实现过程解析
Jun 01 Python
如何基于Python Matplotlib实现网格动画
Jul 20 Python
python实现批处理文件
Jul 28 Python
运行python提示no module named sklearn的解决方法
Nov 29 Python
python 判断文件或文件夹是否存在
Mar 18 Python
Python内置类型集合set和frozenset的使用详解
Apr 26 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
2021年最新CPU天梯图
2021/03/04 数码科技
深入PHP数据缓存的使用说明
2013/05/10 PHP
php 批量替换程序的具体实现代码
2013/10/04 PHP
php preg_replace替换实例讲解
2013/11/04 PHP
jQuery.Autocomplete实现自动完成功能(详解)
2010/07/13 Javascript
使用基于jquery的gamequery插件做JS乒乓球游戏
2011/07/31 Javascript
解决js正则匹配换行问题实现代码
2012/12/10 Javascript
判定是否原生方法的JS代码
2013/11/12 Javascript
一个js控制的导航菜单实例代码
2013/12/03 Javascript
JqueryMobile动态生成listView并实现刷新的两种方法
2014/03/05 Javascript
JavaScript sup方法入门实例(把字符串显示为上标)
2014/10/20 Javascript
可以浮动某个物体的jquery控件用法实例
2015/07/24 Javascript
jquery实现鼠标滑过后动态图片提示效果实例
2015/08/10 Javascript
NodeJS使用formidable实现文件上传
2016/10/27 NodeJs
js实现微博发布小功能
2017/01/12 Javascript
Javascript DOM事件操作小结(监听鼠标点击、释放,悬停、离开等)
2017/01/20 Javascript
微信小程序网络请求的封装与填坑之路
2017/04/01 Javascript
node.js使用http模块创建服务器和客户端完整示例
2020/02/10 Javascript
toString.call()通用的判断数据类型方法示例
2020/08/28 Javascript
js实现随机圆与矩形功能
2020/10/29 Javascript
Python常用正则表达式符号浅析
2014/08/13 Python
Python实现的Google IP 可用性检测脚本
2015/04/23 Python
Python NumPy库安装使用笔记
2015/05/18 Python
Python+OpenCV+图片旋转并用原底色填充新四角的例子
2019/12/12 Python
pytorch 模拟关系拟合——回归实例
2020/01/14 Python
出国导师推荐信
2014/01/16 职场文书
2014信息技术专业毕业生自我评价
2014/01/17 职场文书
师德个人剖析材料
2014/02/02 职场文书
电焊工岗位职责
2014/03/06 职场文书
社会实践活动报告
2015/02/05 职场文书
收银员岗位职责范本
2015/04/07 职场文书
公司周年庆寄语
2019/06/21 职场文书
Jupyter notebook 更改文件打开的默认路径操作
2021/05/21 Python
经典《舰娘》游改全新动画预告 预定11月开播
2022/04/01 日漫
Android RecyclerView实现九宫格效果
2022/06/28 Java/Android
DQL数据查询语句使用示例
2022/12/24 MySQL