使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现正弦信号的时域波形和频谱图示例【基于matplotlib】
May 04 Python
详解Ubuntu16.04安装Python3.7及其pip3并切换为默认版本
Feb 25 Python
Python 中Django安装和使用教程详解
Jul 03 Python
python采集百度搜索结果带有特定URL的链接代码实例
Aug 30 Python
Python类反射机制使用实例解析
Dec 30 Python
Pytorch中实现只导入部分模型参数的方式
Jan 02 Python
Django更新models数据库结构步骤
Apr 01 Python
使用Python和百度语音识别生成视频字幕的实现
Apr 09 Python
浅谈Python __init__.py的作用
Oct 28 Python
python3处理word文档实例分析
Dec 01 Python
Biblibili视频投稿接口分析并以Python实现自动投稿功能
Feb 05 Python
python playwrigh框架入门安装使用
Jul 23 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
分页显示Oracle数据库记录的类之一
2006/10/09 PHP
PHP的cURL库简介及使用示例
2015/02/06 PHP
PHP 数组遍历foreach语法结构及实例
2016/06/13 PHP
PHP使用PHPExcel删除Excel单元格指定列的方法
2016/07/06 PHP
微信公众平台开发教程④ ThinkPHP框架下微信支付功能图文详解
2019/04/10 PHP
js中更短的 Array 类型转换
2011/10/30 Javascript
详解javascript new的运行机制
2016/01/26 Javascript
在AngularJS框架中处理数据建模的方式解析
2016/03/05 Javascript
jquery点击改变class并toggle的实现代码
2016/05/15 Javascript
基于JS实现textarea中获取动态剩余字数的方法
2016/05/25 Javascript
Js得到radiobuttonlist选中值的两种方法(推荐)
2016/08/25 Javascript
js鼠标按键事件和键盘按键事件用法实例汇总
2016/10/03 Javascript
JS声明式函数与赋值式函数实例分析
2016/12/13 Javascript
微信小程序ajax实现请求服务器数据及模版遍历数据功能示例
2017/12/15 Javascript
vue 2.0 购物车小球抛物线的示例代码
2018/02/01 Javascript
用ES6的class模仿Vue写一个双向绑定的示例代码
2018/04/20 Javascript
解决vue打包css文件中背景图片的路径问题
2018/09/03 Javascript
vue axios请求成功却进入catch的原因分析
2020/09/08 Javascript
python计数排序和基数排序算法实例
2014/04/25 Python
详解在Python程序中自定义异常的方法
2015/10/16 Python
Python socket模块实现的udp通信功能示例
2019/04/10 Python
Django实现网页分页功能
2019/10/31 Python
Python 在 VSCode 中使用 IPython Kernel 的方法详解
2020/09/05 Python
医药大学生求职简历的自我评价
2013/10/17 职场文书
应届毕业生专业个人求职自荐信格式
2013/11/20 职场文书
机电一体化专业推荐信
2013/12/03 职场文书
兰兰过桥教学反思
2014/02/08 职场文书
战友聚会策划方案
2014/06/13 职场文书
暑期培训心得体会
2014/09/02 职场文书
2014年教师节演讲稿
2014/09/03 职场文书
通讯稿范文
2015/07/22 职场文书
如何用python插入独创性声明
2021/03/31 Python
SQL Server2019数据库之简单子查询的具有方法
2021/04/27 SQL Server
Vue如何清空对象
2022/03/03 Vue.js
SQL解决未能删除约束问题drop constraint
2022/05/30 SQL Server
MySQL外键约束(Foreign Key)案例详解
2022/06/28 MySQL