使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的socket模块源码中的一些实现要点分析
Jun 06 Python
python中如何使用朴素贝叶斯算法
Apr 06 Python
python代码实现ID3决策树算法
Dec 20 Python
详解django中url路由配置及渲染方式
Feb 25 Python
使用Python将Mysql的查询数据导出到文件的方法
Feb 25 Python
PyQt5下拉式复选框QComboCheckBox的实例
Jun 25 Python
Python脚本利用adb进行手机控制的方法
Jul 08 Python
python实现多进程按序号批量修改文件名的方法示例
Dec 30 Python
python实现简单颜色识别程序
Feb 19 Python
python给list排序的简单方法
Dec 10 Python
Python基础进阶之海量表情包多线程爬虫功能的实现
Dec 17 Python
Selenium浏览器自动化如何上传文件
Apr 06 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
无数据库的详细域名查询程序PHP版(3)
2006/10/09 PHP
Bo-Blog专用的给Windows服务器的IIS Rewrite程序
2007/08/26 PHP
require(),include(),require_once()和include_once()区别
2008/03/27 PHP
php使用正则表达式提取字符串中尖括号、小括号、中括号、大括号中的字符串
2020/04/05 PHP
php实现的Cookies操作类实例
2014/09/24 PHP
laravel 5 实现模板主题功能(续)
2015/03/02 PHP
PHP使用fopen与file_get_contents读取文件实例分享
2016/03/04 PHP
Laravel执行migrate命令提示:No such file or directory的解决方法
2016/03/16 PHP
php封装的smartyBC类完整实例
2016/10/19 PHP
Laravel关系模型指定条件查询方法
2019/10/10 PHP
js批量设置样式的三种方法不推荐使用with
2013/02/25 Javascript
NodeJS学习笔记之Connect中间件模块(二)
2015/01/27 NodeJs
JavaScript中输出信息的方法(信息确认框-提示输入框-文档流输出)
2016/06/12 Javascript
js中获取jsp表单中radio类型的值简单实例
2016/08/15 Javascript
Vue2.0 http请求以及loading展示实例
2018/03/06 Javascript
JavaScript 扩展运算符用法实例小结【基于ES6】
2019/06/17 Javascript
elementui之el-tebs浏览器卡死的问题和使用报错未注册问题
2019/07/06 Javascript
Javascript操作select控件代码实例
2020/02/14 Javascript
在vue中使用console.log无效的解决
2020/08/09 Javascript
[02:03]完美世界DOTA2联赛10月30日赛事集锦
2020/10/31 DOTA
opencv实现静态手势识别 opencv实现剪刀石头布游戏
2019/01/22 Python
详解python tkinter教程-事件绑定
2019/03/28 Python
Python自动抢红包教程详解
2019/06/11 Python
深入浅析Python 中的sklearn模型选择
2019/10/12 Python
基于python使用tibco ems代码实例
2019/12/20 Python
python实现猜拳游戏
2020/03/04 Python
特罗佩亚包官方网站:Tropea
2017/01/03 全球购物
瑜伽服装品牌:露露柠檬(lululemon athletica)
2017/06/04 全球购物
英国马莎百货印度官网:Marks & Spencer印度
2020/10/08 全球购物
文艺晚会主持词
2014/03/24 职场文书
大学生英语演讲稿
2014/04/24 职场文书
英文演讲稿
2014/05/15 职场文书
导游词之金鞭溪风景区
2019/09/12 职场文书
导游词之河北野三坡
2019/12/11 职场文书
MySQL的全局锁和表级锁的具体使用
2021/08/23 MySQL
分布式架构Redis中有哪些数据结构及底层实现原理
2022/03/13 Redis