使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解使用Python处理文件目录的相关方法
Oct 16 Python
Python求算数平方根和约数的方法汇总
Mar 09 Python
Python实现递归遍历文件夹并删除文件
Apr 18 Python
Python脚本实现12306火车票查询系统
Sep 30 Python
pandas获取groupby分组里最大值所在的行方法
Apr 20 Python
解决PyCharm同目录下导入模块会报错的问题
Oct 13 Python
Python Tkinter模块 GUI 可视化实例
Nov 20 Python
使用Python求解带约束的最优化问题详解
Feb 11 Python
python 实现一个简单的线性回归案例
Dec 17 Python
python实现三种随机请求头方式
Jan 05 Python
python中封包建立过程实例
Feb 18 Python
Django使用channels + websocket打造在线聊天室
May 20 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
php简单提示框alert封装函数
2010/08/08 PHP
php HandlerSocket的使用
2011/05/02 PHP
php学习笔记之 函数声明(二)
2011/06/09 PHP
codeigniter教程之上传视频并使用ffmpeg转flv示例
2014/02/13 PHP
php延迟静态绑定实例分析
2015/02/08 PHP
又十个超级有用的PHP代码片段
2015/09/24 PHP
PHP实现唤起微信支付功能
2019/02/18 PHP
jquery异步循环获取功能实现代码
2010/09/19 Javascript
javaScript实现滚动新闻的方法
2015/07/30 Javascript
jquery中ajax跨域方法实例分析
2015/12/18 Javascript
jQuery简单验证上传文件大小及类型的方法
2016/06/02 Javascript
vuejs在解析时出现闪烁的原因及防止闪烁的方法
2016/09/19 Javascript
jQuery自定义组件(导入组件)
2016/11/08 Javascript
JS实现获取来自百度,Google,soso,sogou关键词的方法
2016/12/21 Javascript
基于JavaScript实现复选框的全选和取消全选
2017/02/09 Javascript
使用D3.js制作图表详解
2017/08/13 Javascript
angularjs通过过滤器返回超链接的方法
2018/10/26 Javascript
小程序实现选择题选择效果
2018/11/04 Javascript
利用d3.js制作连线动画图与编辑器的方法实例
2019/09/05 Javascript
JavaScript实现PC端横向轮播图
2020/02/07 Javascript
在Vue 中获取下拉框的文本及选项值操作
2020/08/13 Javascript
[01:25]2014DOTA2国际邀请赛 zhou分析LGD比赛情况
2014/07/14 DOTA
[43:03]完美世界DOTA2联赛PWL S2 PXG vs Magma 第二场 11.21
2020/11/24 DOTA
python 运用Django 开发后台接口的实例
2018/12/11 Python
Python面向对象基础入门之编码细节与注意事项
2018/12/11 Python
用python 实现在不确定行数情况下多行输入方法
2019/01/28 Python
对python实现模板生成脚本的方法详解
2019/01/30 Python
python 字符串常用函数详解
2019/09/11 Python
python 基于Apscheduler实现定时任务
2020/12/15 Python
上课迟到检讨书
2014/01/19 职场文书
文明寄语大全
2014/04/11 职场文书
经济贸易系毕业生求职信
2014/05/31 职场文书
公安机关纪律作风整顿个人剖析材料材料
2014/10/10 职场文书
死亡证明书样本说明
2014/10/18 职场文书
2016年六一儿童节开幕词
2016/03/04 职场文书
教师实习自我鉴定总结
2019/08/20 职场文书