使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python深入学习之特殊方法与多范式
Aug 31 Python
详解在Python程序中自定义异常的方法
Oct 16 Python
Python使用中文正则表达式匹配指定中文字符串的方法示例
Jan 20 Python
Python 获得命令行参数的方法(推荐)
Jan 24 Python
详解Python装饰器
Mar 25 Python
使用Matplotlib 绘制精美的数学图形例子
Dec 13 Python
Python使用type动态创建类操作示例
Feb 29 Python
jupyter note 实现将数据保存为word
Apr 14 Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 Python
Python定时任务APScheduler安装及使用解析
Aug 07 Python
Python获取android设备cpu和内存占用情况
Nov 15 Python
使用豆瓣源来安装python中的第三方库方法
Jan 26 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
百度工程师讲PHP函数的实现原理及性能分析(三)
2015/05/13 PHP
PHP中16个高危函数整理
2019/09/19 PHP
js利用div背景,做一个竖线的效果。
2008/11/22 Javascript
Javascript 中 null、NaN和undefined的区别总结
2013/04/10 Javascript
Chrome扩展页面动态绑定JS事件提示错误
2014/02/11 Javascript
js自动查找select下拉的菜单并选择(示例代码)
2014/02/26 Javascript
javascript 中that的含义示例介绍
2014/05/14 Javascript
js实现图片和链接文字同步切换特效的方法
2015/02/20 Javascript
Javascript中判断对象是否为空
2015/06/10 Javascript
基于jquery实现省市联动特效
2015/12/17 Javascript
AngularJs $parse、$eval和$observe、$watch详解
2016/09/21 Javascript
jQuery动态添加与删除tr行实例代码
2016/10/18 Javascript
微信小程序 swiper组件轮播图详解及实例
2016/11/16 Javascript
JS定时器实现数值从0到10来回变化
2016/12/09 Javascript
详解Vue监听数据变化原理
2017/03/08 Javascript
浅析webpack 如何优雅的使用tree-shaking(摇树优化)
2017/08/16 Javascript
浅谈React Native Flexbox布局(小结)
2018/01/08 Javascript
js+html5 canvas实现ps钢笔抠图
2019/04/28 Javascript
详解微信小程序实现跑马灯效果(附完整代码)
2019/04/29 Javascript
举例讲解如何在Python编程中进行迭代和遍历
2016/01/19 Python
使用python实现省市三级菜单效果
2016/01/20 Python
Python字符编码判断方法分析
2016/07/01 Python
python scipy卷积运算的实现方法
2019/09/16 Python
python 如何把docker-compose.yaml导入到数据库相关条目里
2021/01/15 Python
pytorch下的unsqueeze和squeeze的用法说明
2021/02/06 Python
css3的@media属性实现页面响应式布局示例代码
2014/02/10 HTML / CSS
接受捐赠答谢词
2014/01/27 职场文书
大学生求职计划书
2014/04/30 职场文书
环保倡议书怎么写
2014/05/16 职场文书
销售提升方案
2014/06/07 职场文书
2014村书记党建工作汇报材料
2014/11/02 职场文书
读后感作文评语
2014/12/25 职场文书
好员工观后感
2015/06/17 职场文书
入党后的感想
2015/08/10 职场文书
MySQL令人咋舌的隐式转换
2021/04/05 MySQL
Javascript 解构赋值详情
2021/11/17 Javascript