使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解析json实例方法
Nov 19 Python
Python群发邮件实例代码
Jan 03 Python
python实现的二叉树定义与遍历算法实例
Jun 30 Python
Python3 获取一大段文本之间两个关键字之间的内容方法
Oct 11 Python
python浪漫表白源码
Apr 05 Python
一行Python代码过滤标点符号等特殊字符
Aug 12 Python
Python实现微信中找回好友、群聊用户撤回的消息功能示例
Aug 23 Python
解决Python计算矩阵乘向量,矩阵乘实数的一些小错误
Aug 26 Python
关于Flask项目无法使用公网IP访问的解决方式
Nov 19 Python
python两个_多个字典合并相加的实例代码
Dec 26 Python
使用python绘制cdf的多种实现方法
Feb 25 Python
利用python做数据拟合详情
Nov 17 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
深入apache配置文件httpd.conf的部分参数说明
2013/06/28 PHP
PHP高精确度运算BC函数库实例详解
2017/08/15 PHP
PHP+mysql实现的三级联动菜单功能示例
2019/02/15 PHP
Maps Javascript
2007/01/22 Javascript
了不起的node.js读书笔记之node的学习总结
2014/12/22 Javascript
JS实现在网页中弹出一个输入框的方法
2015/03/03 Javascript
jQuery给动态添加的元素绑定事件的方法
2015/03/09 Javascript
node.js require() 源码解读
2015/12/13 Javascript
JS数组去掉重复数据只保留一条的实现代码
2016/08/11 Javascript
JavaScript中利用for循环遍历数组
2017/01/15 Javascript
jQuery给表格添加分页效果
2017/03/02 Javascript
JS实现的Unicode编码转换操作示例
2017/04/28 Javascript
node中Express 动态设置端口的方法
2017/08/04 Javascript
Vue组件通信的四种方式汇总
2018/02/08 Javascript
vue.js,ajax渲染页面的实例
2018/02/11 Javascript
Vue路由守卫之路由独享守卫
2019/09/25 Javascript
PyQt5每天必学之QSplitter实现窗口分隔
2018/04/19 Python
python 实现在txt指定行追加文本的方法
2018/04/29 Python
Opencv-Python图像透视变换cv2.warpPerspective的示例
2019/04/11 Python
python pytest进阶之xunit fixture详解
2019/06/27 Python
Python+selenium点击网页上指定坐标的实例
2019/07/05 Python
pandas的排序和排名的具体使用
2019/07/31 Python
Python中断多重循环的思路总结
2019/10/04 Python
Python3 sys.argv[ ]用法详解
2019/10/24 Python
Pytorch十九种损失函数的使用详解
2020/04/29 Python
美国高级工作服品牌:Carhartt
2018/01/25 全球购物
美国男士和女士奢侈品折扣手表购物网站:Certified Watch Store
2018/06/13 全球购物
奥地利时尚、美容、玩具和家居之家:Kastner & Öhler
2020/04/26 全球购物
什么是规则表达式
2012/05/03 面试题
电子专业求职信
2014/06/19 职场文书
社团活动总结怎么写
2014/06/30 职场文书
2014小学教师个人工作总结
2014/11/10 职场文书
2016年大学生社区服务活动总结
2016/04/06 职场文书
SpringCloud Function SpEL注入漏洞分析及环境搭建
2022/04/08 Java/Android
《勇者辞职不干了》上卷BD发售宣传CM公开
2022/04/08 日漫
Python可视化神器pyecharts绘制地理图表
2022/07/07 Python